feat: add endpoint to fetch available DBs (#14208)
* feat: add endpoint to fetch available DBs * Fix lint
This commit is contained in:
parent
ffcacc3393
commit
e7ad03d44f
|
|
@ -1064,6 +1064,18 @@ SQL_VALIDATORS_BY_ENGINE = {
|
|||
"postgresql": "PostgreSQLValidator",
|
||||
}
|
||||
|
||||
# A list of preferred databases, in order. These databases will be
|
||||
# displayed prominently in the "Add Database" dialog. You should
|
||||
# use the "engine" attribute of the corresponding DB engine spec in
|
||||
# `superset/db_engine_specs/`.
|
||||
PREFERRED_DATABASES: List[str] = [
|
||||
# "postgresql",
|
||||
# "presto",
|
||||
# "mysql",
|
||||
# "sqlite",
|
||||
# etc.
|
||||
]
|
||||
|
||||
# Do you want Talisman enabled?
|
||||
TALISMAN_ENABLED = False
|
||||
# If you want Talisman, how do you want it configured??
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ import json
|
|||
import logging
|
||||
from datetime import datetime
|
||||
from io import BytesIO
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
from zipfile import ZipFile
|
||||
|
||||
from flask import g, request, Response, send_file
|
||||
|
|
@ -27,7 +27,7 @@ from flask_appbuilder.models.sqla.interface import SQLAInterface
|
|||
from marshmallow import ValidationError
|
||||
from sqlalchemy.exc import NoSuchTableError, OperationalError, SQLAlchemyError
|
||||
|
||||
from superset import event_logger
|
||||
from superset import app, event_logger
|
||||
from superset.commands.exceptions import CommandInvalidError
|
||||
from superset.commands.importers.v1.utils import get_contents_from_bundle
|
||||
from superset.constants import MODEL_API_RW_METHOD_PERMISSION_MAP, RouteMethod
|
||||
|
|
@ -63,6 +63,8 @@ from superset.databases.schemas import (
|
|||
TableMetadataResponseSchema,
|
||||
)
|
||||
from superset.databases.utils import get_table_metadata
|
||||
from superset.db_engine_specs import get_available_engine_specs
|
||||
from superset.db_engine_specs.base import BaseParametersMixin
|
||||
from superset.extensions import security_manager
|
||||
from superset.models.core import Database
|
||||
from superset.typing import FlaskResponse
|
||||
|
|
@ -84,6 +86,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
|||
"test_connection",
|
||||
"related_objects",
|
||||
"function_names",
|
||||
"available",
|
||||
}
|
||||
resource_name = "database"
|
||||
class_permission_name = "Database"
|
||||
|
|
@ -821,7 +824,6 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
|||
schema:
|
||||
type: integer
|
||||
responses:
|
||||
200:
|
||||
200:
|
||||
description: Query result
|
||||
content:
|
||||
|
|
@ -839,3 +841,67 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
|||
if not database:
|
||||
return self.response_404()
|
||||
return self.response(200, function_names=database.function_names,)
|
||||
|
||||
@expose("/available/", methods=["GET"])
|
||||
@protect()
|
||||
@statsd_metrics
|
||||
@event_logger.log_this_with_context(
|
||||
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}" f".available",
|
||||
log_to_statsd=False,
|
||||
)
|
||||
def available(self) -> Response:
|
||||
"""Return names of databases currently available
|
||||
---
|
||||
get:
|
||||
description:
|
||||
Get names of databases currently available
|
||||
responses:
|
||||
200:
|
||||
description: Database names
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: array
|
||||
items:
|
||||
type: object
|
||||
properties:
|
||||
name:
|
||||
description: Name of the database
|
||||
type: string
|
||||
preferred:
|
||||
description: Is the database preferred?
|
||||
type: bool
|
||||
sqlalchemy_uri_placeholder:
|
||||
description: Example placeholder for the SQLAlchemy URI
|
||||
type: string
|
||||
parameters:
|
||||
description: JSON schema defining the needed parameters
|
||||
400:
|
||||
$ref: '#/components/responses/400'
|
||||
500:
|
||||
$ref: '#/components/responses/500'
|
||||
"""
|
||||
preferred_databases: List[str] = app.config.get("PREFERRED_DATABASES", [])
|
||||
available_databases = []
|
||||
for engine_spec in get_available_engine_specs():
|
||||
payload: Dict[str, Any] = {
|
||||
"name": engine_spec.engine_name,
|
||||
"engine": engine_spec.engine,
|
||||
"preferred": engine_spec.engine in preferred_databases,
|
||||
}
|
||||
|
||||
if issubclass(engine_spec, BaseParametersMixin):
|
||||
payload["parameters"] = engine_spec.parameters_json_schema()
|
||||
payload[
|
||||
"sqlalchemy_uri_placeholder"
|
||||
] = engine_spec.sqlalchemy_uri_placeholder
|
||||
|
||||
available_databases.append(payload)
|
||||
|
||||
available_databases.sort(
|
||||
key=lambda payload: preferred_databases.index(payload["engine"])
|
||||
if payload["engine"] in preferred_databases
|
||||
else len(preferred_databases)
|
||||
)
|
||||
|
||||
return self.response(200, databases=available_databases)
|
||||
|
|
|
|||
|
|
@ -21,12 +21,14 @@ from typing import Any, Dict
|
|||
|
||||
from flask import current_app
|
||||
from flask_babel import lazy_gettext as _
|
||||
from marshmallow import fields, Schema, validates_schema
|
||||
from marshmallow import fields, pre_load, Schema, validates_schema
|
||||
from marshmallow.validate import Length, ValidationError
|
||||
from sqlalchemy import MetaData
|
||||
from sqlalchemy.engine.url import make_url
|
||||
from sqlalchemy.exc import ArgumentError
|
||||
|
||||
from superset.db_engine_specs import get_engine_specs
|
||||
from superset.db_engine_specs.base import BaseParametersMixin
|
||||
from superset.exceptions import CertificateException, SupersetSecurityException
|
||||
from superset.models.core import PASSWORD_MASK
|
||||
from superset.security.analytics_db_safety import check_sqlalchemy_uri
|
||||
|
|
@ -207,7 +209,72 @@ def extra_validator(value: str) -> str:
|
|||
return value
|
||||
|
||||
|
||||
class DatabasePostSchema(Schema):
|
||||
class DatabaseParametersSchemaMixin:
|
||||
"""
|
||||
Allow SQLAlchemy URI to be passed as separate parameters.
|
||||
|
||||
This mixing is a first step in allowing the users to test, create and
|
||||
edit databases without having to know how to write a SQLAlchemy URI.
|
||||
Instead, each databases defines the parameters that it takes (eg,
|
||||
username, password, host, etc.) and the SQLAlchemy URI is built from
|
||||
these parameters.
|
||||
|
||||
When using this mixin make sure that `sqlalchemy_uri` is not required.
|
||||
"""
|
||||
|
||||
parameters = fields.Dict(
|
||||
keys=fields.Str(),
|
||||
values=fields.Raw(),
|
||||
description="DB-specific parameters for configuration",
|
||||
)
|
||||
|
||||
# pylint: disable=no-self-use, unused-argument
|
||||
@pre_load
|
||||
def build_sqlalchemy_uri(
|
||||
self, data: Dict[str, Any], **kwargs: Any
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Build SQLAlchemy URI from separate parameters.
|
||||
|
||||
This is used for databases that support being configured by individual
|
||||
parameters (eg, username, password, host, etc.), instead of requiring
|
||||
the constructed SQLAlchemy URI to be passed.
|
||||
"""
|
||||
parameters = data.pop("parameters", None)
|
||||
if parameters:
|
||||
if "engine" not in parameters:
|
||||
raise ValidationError(
|
||||
[
|
||||
_(
|
||||
"An engine must be specified when passing "
|
||||
"individual parameters to a database."
|
||||
)
|
||||
]
|
||||
)
|
||||
engine = parameters["engine"]
|
||||
|
||||
engine_specs = get_engine_specs()
|
||||
if engine not in engine_specs:
|
||||
raise ValidationError(
|
||||
[_('Engine "%(engine)s" is not a valid engine.', engine=engine,)]
|
||||
)
|
||||
engine_spec = engine_specs[engine]
|
||||
if not issubclass(engine_spec, BaseParametersMixin):
|
||||
raise ValidationError(
|
||||
[
|
||||
_(
|
||||
'Engine spec "%(engine_spec)s" does not support '
|
||||
"being configured via individual parameters.",
|
||||
engine_spec=engine_spec.__name__,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
data["sqlalchemy_uri"] = engine_spec.build_sqlalchemy_url(parameters)
|
||||
return data
|
||||
|
||||
|
||||
class DatabasePostSchema(Schema, DatabaseParametersSchemaMixin):
|
||||
database_name = fields.String(
|
||||
description=database_name_description, required=True, validate=Length(1, 250),
|
||||
)
|
||||
|
|
@ -242,12 +309,11 @@ class DatabasePostSchema(Schema):
|
|||
)
|
||||
sqlalchemy_uri = fields.String(
|
||||
description=sqlalchemy_uri_description,
|
||||
required=True,
|
||||
validate=[Length(1, 1024), sqlalchemy_uri_validator],
|
||||
)
|
||||
|
||||
|
||||
class DatabasePutSchema(Schema):
|
||||
class DatabasePutSchema(Schema, DatabaseParametersSchemaMixin):
|
||||
database_name = fields.String(
|
||||
description=database_name_description, allow_none=True, validate=Length(1, 250),
|
||||
)
|
||||
|
|
@ -282,12 +348,11 @@ class DatabasePutSchema(Schema):
|
|||
)
|
||||
sqlalchemy_uri = fields.String(
|
||||
description=sqlalchemy_uri_description,
|
||||
allow_none=True,
|
||||
validate=[Length(0, 1024), sqlalchemy_uri_validator],
|
||||
)
|
||||
|
||||
|
||||
class DatabaseTestConnectionSchema(Schema):
|
||||
class DatabaseTestConnectionSchema(Schema, DatabaseParametersSchemaMixin):
|
||||
database_name = fields.String(
|
||||
description=database_name_description, allow_none=True, validate=Length(1, 250),
|
||||
)
|
||||
|
|
@ -305,7 +370,6 @@ class DatabaseTestConnectionSchema(Schema):
|
|||
)
|
||||
sqlalchemy_uri = fields.String(
|
||||
description=sqlalchemy_uri_description,
|
||||
required=True,
|
||||
validate=[Length(1, 1024), sqlalchemy_uri_validator],
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -32,8 +32,9 @@ import logging
|
|||
import pkgutil
|
||||
from importlib import import_module
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Type
|
||||
from typing import Any, Dict, List, Set, Type
|
||||
|
||||
import sqlalchemy.databases
|
||||
from pkg_resources import iter_entry_points
|
||||
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
|
|
@ -67,7 +68,7 @@ def get_engine_specs() -> Dict[str, Type[BaseEngineSpec]]:
|
|||
try:
|
||||
engine_spec = ep.load()
|
||||
except Exception: # pylint: disable=broad-except
|
||||
logger.warning("Unable to load engine spec: %s", engine_spec)
|
||||
logger.warning("Unable to load Superset DB engine spec: %s", engine_spec)
|
||||
continue
|
||||
engine_specs.append(engine_spec)
|
||||
|
||||
|
|
@ -82,3 +83,23 @@ def get_engine_specs() -> Dict[str, Type[BaseEngineSpec]]:
|
|||
engine_specs_map[name] = engine_spec
|
||||
|
||||
return engine_specs_map
|
||||
|
||||
|
||||
def get_available_engine_specs() -> List[Type[BaseEngineSpec]]:
|
||||
# native SQLAlchemy dialects
|
||||
backends: Set[str] = {
|
||||
getattr(sqlalchemy.databases, attr).dialect.name
|
||||
for attr in sqlalchemy.databases.__all__
|
||||
}
|
||||
|
||||
# installed 3rd-party dialects
|
||||
for ep in iter_entry_points("sqlalchemy.dialects"):
|
||||
try:
|
||||
dialect = ep.load()
|
||||
except Exception: # pylint: disable=broad-except
|
||||
logger.warning("Unable to load SQLAlchemy dialect: %s", dialect)
|
||||
else:
|
||||
backends.add(dialect.name)
|
||||
|
||||
engine_specs = get_engine_specs()
|
||||
return [engine_specs[backend] for backend in backends if backend in engine_specs]
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ from typing import (
|
|||
NamedTuple,
|
||||
Optional,
|
||||
Pattern,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
TYPE_CHECKING,
|
||||
|
|
@ -38,18 +39,22 @@ from typing import (
|
|||
|
||||
import pandas as pd
|
||||
import sqlparse
|
||||
from apispec import APISpec
|
||||
from apispec.ext.marshmallow import MarshmallowPlugin
|
||||
from flask import g
|
||||
from flask_babel import gettext as __, lazy_gettext as _
|
||||
from marshmallow import fields, Schema
|
||||
from sqlalchemy import column, DateTime, select, types
|
||||
from sqlalchemy.engine.base import Engine
|
||||
from sqlalchemy.engine.interfaces import Compiled, Dialect
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
from sqlalchemy.engine.url import URL
|
||||
from sqlalchemy.engine.url import make_url, URL
|
||||
from sqlalchemy.ext.compiler import compiles
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql import quoted_name, text
|
||||
from sqlalchemy.sql.expression import ColumnClause, Select, TextAsFrom
|
||||
from sqlalchemy.types import String, TypeEngine, UnicodeText
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from superset import app, security_manager, sql_parse
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
|
|
@ -150,7 +155,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
"""
|
||||
|
||||
engine = "base" # str as defined in sqlalchemy.engine.engine
|
||||
engine_aliases: Optional[Tuple[str]] = None
|
||||
engine_aliases: Set[str] = set()
|
||||
engine_name: Optional[
|
||||
str
|
||||
] = None # used for user messages, overridden in child classes
|
||||
|
|
@ -937,6 +942,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
:param cols: Columns to include in query
|
||||
:return: SQL query
|
||||
"""
|
||||
# pylint: disable=redefined-outer-name
|
||||
fields: Union[str, List[Any]] = "*"
|
||||
cols = cols or []
|
||||
if (show_cols or latest_partition) and not cols:
|
||||
|
|
@ -1293,3 +1299,90 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
sqla_type=column_type, generic_type=generic_type, is_dttm=is_dttm
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
# schema for adding a database by providing parameters instead of the
|
||||
# full SQLAlchemy URI
|
||||
class BaseParametersSchema(Schema):
|
||||
username = fields.String(allow_none=True, description=__("Username"))
|
||||
password = fields.String(allow_none=True, description=__("Password"))
|
||||
host = fields.String(required=True, description=__("Hostname or IP address"))
|
||||
port = fields.Integer(required=True, description=__("Database port"))
|
||||
database = fields.String(required=True, description=__("Database name"))
|
||||
query = fields.Dict(
|
||||
keys=fields.Str(), values=fields.Raw(), description=__("Additinal parameters")
|
||||
)
|
||||
|
||||
|
||||
class BaseParametersType(TypedDict, total=False):
|
||||
username: Optional[str]
|
||||
password: Optional[str]
|
||||
host: str
|
||||
port: int
|
||||
database: str
|
||||
query: Dict[str, Any]
|
||||
|
||||
|
||||
class BaseParametersMixin:
|
||||
|
||||
"""
|
||||
Mixin for configuring DB engine specs via a dictionary.
|
||||
|
||||
With this mixin the SQLAlchemy engine can be configured through
|
||||
individual parameters, instead of the full SQLAlchemy URI. This
|
||||
mixin is for the most common pattern of URI:
|
||||
|
||||
drivername://user:password@host:port/dbname[?key=value&key=value...]
|
||||
|
||||
"""
|
||||
|
||||
# schema describing the parameters used to configure the DB
|
||||
parameters_schema = BaseParametersSchema()
|
||||
|
||||
# recommended driver name for the DB engine spec
|
||||
drivername = ""
|
||||
|
||||
# placeholder with the SQLAlchemy URI template
|
||||
sqlalchemy_uri_placeholder = (
|
||||
"drivername://user:password@host:port/dbname[?key=value&key=value...]"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def build_sqlalchemy_url(cls, parameters: BaseParametersType) -> str:
|
||||
return str(
|
||||
URL(
|
||||
cls.drivername,
|
||||
username=parameters.get("username"),
|
||||
password=parameters.get("password"),
|
||||
host=parameters["host"],
|
||||
port=parameters["port"],
|
||||
database=parameters["database"],
|
||||
query=parameters.get("query", {}),
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_parameters_from_uri(cls, uri: str) -> BaseParametersType:
|
||||
url = make_url(uri)
|
||||
return {
|
||||
"username": url.username,
|
||||
"password": url.password,
|
||||
"host": url.host,
|
||||
"port": url.port,
|
||||
"database": url.database,
|
||||
"query": url.query,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def parameters_json_schema(cls) -> Any:
|
||||
"""
|
||||
Return configuration parameters as OpenAPI.
|
||||
"""
|
||||
spec = APISpec(
|
||||
title="Database Parameters",
|
||||
version="1.0.0",
|
||||
openapi_version="3.0.2",
|
||||
plugins=[MarshmallowPlugin()],
|
||||
)
|
||||
spec.components.schema(cls.__name__, schema=cls.parameters_schema)
|
||||
return spec.to_dict()["components"]["schemas"][cls.__name__]
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ from sqlalchemy.dialects.postgresql import ARRAY, DOUBLE_PRECISION, ENUM, JSON
|
|||
from sqlalchemy.dialects.postgresql.base import PGInspector
|
||||
from sqlalchemy.types import String, TypeEngine
|
||||
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
from superset.db_engine_specs.base import BaseEngineSpec, BaseParametersMixin
|
||||
from superset.errors import SupersetErrorType
|
||||
from superset.exceptions import SupersetException
|
||||
from superset.utils import core as utils
|
||||
|
|
@ -143,9 +143,15 @@ class PostgresBaseEngineSpec(BaseEngineSpec):
|
|||
return "(timestamp 'epoch' + {col} * interval '1 second')"
|
||||
|
||||
|
||||
class PostgresEngineSpec(PostgresBaseEngineSpec):
|
||||
class PostgresEngineSpec(PostgresBaseEngineSpec, BaseParametersMixin):
|
||||
engine = "postgresql"
|
||||
engine_aliases = ("postgres",)
|
||||
engine_aliases = {"postgres"}
|
||||
|
||||
drivername = "postgresql+psycopg2"
|
||||
sqlalchemy_uri_placeholder = (
|
||||
"postgresql+psycopg2://user:password@host:port/dbname[?key=value&key=value...]"
|
||||
)
|
||||
|
||||
max_column_name_length = 63
|
||||
try_remove_schema_from_table_name = False
|
||||
|
||||
|
|
|
|||
|
|
@ -33,6 +33,8 @@ from sqlalchemy.sql import func
|
|||
|
||||
from superset import db, security_manager
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.db_engine_specs.mysql import MySQLEngineSpec
|
||||
from superset.db_engine_specs.postgres import PostgresEngineSpec
|
||||
from superset.errors import SupersetError
|
||||
from superset.models.core import Database
|
||||
from superset.models.reports import ReportSchedule, ReportScheduleType
|
||||
|
|
@ -613,7 +615,8 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
assert "can_read" in data["permissions"]
|
||||
assert "can_write" in data["permissions"]
|
||||
assert "can_function_names" in data["permissions"]
|
||||
assert len(data["permissions"]) == 3
|
||||
assert "can_available" in data["permissions"]
|
||||
assert len(data["permissions"]) == 4
|
||||
|
||||
def test_get_invalid_database_table_metadata(self):
|
||||
"""
|
||||
|
|
@ -1245,3 +1248,65 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
|
||||
assert rv.status_code == 200
|
||||
assert response == {"function_names": ["AVG", "MAX", "SUM"]}
|
||||
|
||||
@mock.patch("superset.databases.api.get_available_engine_specs")
|
||||
@mock.patch("superset.databases.api.app")
|
||||
def test_available(self, app, get_available_engine_specs):
|
||||
app.config = {"PREFERRED_DATABASES": ["postgresql"]}
|
||||
get_available_engine_specs.return_value = [
|
||||
MySQLEngineSpec,
|
||||
PostgresEngineSpec,
|
||||
]
|
||||
|
||||
self.login(username="admin")
|
||||
uri = "api/v1/database/available/"
|
||||
|
||||
rv = self.client.get(uri)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
|
||||
assert rv.status_code == 200
|
||||
assert response == {
|
||||
"databases": [
|
||||
{
|
||||
"engine": "postgresql",
|
||||
"name": "PostgreSQL",
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"database": {
|
||||
"description": "Database name",
|
||||
"type": "string",
|
||||
},
|
||||
"host": {
|
||||
"description": "Hostname or IP address",
|
||||
"type": "string",
|
||||
},
|
||||
"password": {
|
||||
"description": "Password",
|
||||
"nullable": True,
|
||||
"type": "string",
|
||||
},
|
||||
"port": {
|
||||
"description": "Database port",
|
||||
"format": "int32",
|
||||
"type": "integer",
|
||||
},
|
||||
"query": {
|
||||
"additionalProperties": {},
|
||||
"description": "Additinal parameters",
|
||||
"type": "object",
|
||||
},
|
||||
"username": {
|
||||
"description": "Username",
|
||||
"nullable": True,
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
"required": ["database", "host", "port"],
|
||||
"type": "object",
|
||||
},
|
||||
"preferred": True,
|
||||
"sqlalchemy_uri_placeholder": "postgresql+psycopg2://user:password@host:port/dbname[?key=value&key=value...]",
|
||||
},
|
||||
{"engine": "mysql", "name": "MySQL", "preferred": False},
|
||||
]
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,125 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from unittest import mock
|
||||
|
||||
from marshmallow import fields, Schema, ValidationError
|
||||
|
||||
from superset.databases.schemas import DatabaseParametersSchemaMixin
|
||||
from superset.db_engine_specs.base import BaseParametersMixin
|
||||
|
||||
|
||||
class DummySchema(Schema, DatabaseParametersSchemaMixin):
|
||||
sqlalchemy_uri = fields.String()
|
||||
|
||||
|
||||
class DummyEngine(BaseParametersMixin):
|
||||
drivername = "dummy"
|
||||
|
||||
|
||||
class InvalidEngine:
|
||||
pass
|
||||
|
||||
|
||||
@mock.patch("superset.databases.schemas.get_engine_specs")
|
||||
def test_database_parameters_schema_mixin(get_engine_specs):
|
||||
get_engine_specs.return_value = {"dummy_engine": DummyEngine}
|
||||
payload = {
|
||||
"parameters": {
|
||||
"engine": "dummy_engine",
|
||||
"username": "username",
|
||||
"password": "password",
|
||||
"host": "localhost",
|
||||
"port": 12345,
|
||||
"database": "dbname",
|
||||
}
|
||||
}
|
||||
schema = DummySchema()
|
||||
result = schema.load(payload)
|
||||
assert result == {
|
||||
"sqlalchemy_uri": "dummy://username:password@localhost:12345/dbname"
|
||||
}
|
||||
|
||||
|
||||
def test_database_parameters_schema_mixin_no_engine():
|
||||
payload = {
|
||||
"parameters": {
|
||||
"username": "username",
|
||||
"password": "password",
|
||||
"host": "localhost",
|
||||
"port": 12345,
|
||||
"dbname": "dbname",
|
||||
}
|
||||
}
|
||||
schema = DummySchema()
|
||||
try:
|
||||
schema.load(payload)
|
||||
except ValidationError as err:
|
||||
assert err.messages == {
|
||||
"_schema": [
|
||||
"An engine must be specified when passing individual parameters to a database."
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@mock.patch("superset.databases.schemas.get_engine_specs")
|
||||
def test_database_parameters_schema_mixin_invalid_engine(get_engine_specs):
|
||||
get_engine_specs.return_value = {}
|
||||
payload = {
|
||||
"parameters": {
|
||||
"engine": "dummy_engine",
|
||||
"username": "username",
|
||||
"password": "password",
|
||||
"host": "localhost",
|
||||
"port": 12345,
|
||||
"dbname": "dbname",
|
||||
}
|
||||
}
|
||||
schema = DummySchema()
|
||||
try:
|
||||
schema.load(payload)
|
||||
except ValidationError as err:
|
||||
assert err.messages == {
|
||||
"_schema": ['Engine "dummy_engine" is not a valid engine.']
|
||||
}
|
||||
|
||||
|
||||
@mock.patch("superset.databases.schemas.get_engine_specs")
|
||||
def test_database_parameters_schema_no_mixin(get_engine_specs):
|
||||
get_engine_specs.return_value = {"invalid_engine": InvalidEngine}
|
||||
payload = {
|
||||
"parameters": {
|
||||
"engine": "invalid_engine",
|
||||
"username": "username",
|
||||
"password": "password",
|
||||
"host": "localhost",
|
||||
"port": 12345,
|
||||
"database": "dbname",
|
||||
}
|
||||
}
|
||||
schema = DummySchema()
|
||||
try:
|
||||
schema.load(payload)
|
||||
except ValidationError as err:
|
||||
assert err.messages == {
|
||||
"_schema": [
|
||||
(
|
||||
'Engine spec "InvalidEngine" does not support '
|
||||
"being configured via individual parameters."
|
||||
)
|
||||
]
|
||||
}
|
||||
|
|
@ -388,3 +388,44 @@ psql: error: could not connect to server: Operation timed out
|
|||
},
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def test_base_parameters_mixin():
|
||||
parameters = {
|
||||
"username": "username",
|
||||
"password": "password",
|
||||
"host": "localhost",
|
||||
"port": 5432,
|
||||
"database": "dbname",
|
||||
"query": {"foo": "bar"},
|
||||
}
|
||||
sqlalchemy_uri = PostgresEngineSpec.build_sqlalchemy_url(parameters)
|
||||
assert (
|
||||
sqlalchemy_uri
|
||||
== "postgresql+psycopg2://username:password@localhost:5432/dbname?foo=bar"
|
||||
)
|
||||
|
||||
parameters_from_uri = PostgresEngineSpec.get_parameters_from_uri(sqlalchemy_uri)
|
||||
assert parameters_from_uri == parameters
|
||||
|
||||
json_schema = PostgresEngineSpec.parameters_json_schema()
|
||||
assert json_schema == {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"host": {"type": "string", "description": "Hostname or IP address"},
|
||||
"username": {"type": "string", "nullable": True, "description": "Username"},
|
||||
"password": {"type": "string", "nullable": True, "description": "Password"},
|
||||
"database": {"type": "string", "description": "Database name"},
|
||||
"query": {
|
||||
"type": "object",
|
||||
"description": "Additinal parameters",
|
||||
"additionalProperties": {},
|
||||
},
|
||||
"port": {
|
||||
"type": "integer",
|
||||
"format": "int32",
|
||||
"description": "Database port",
|
||||
},
|
||||
},
|
||||
"required": ["database", "host", "port"],
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue