feat(databases): test connection api (#10723)

* test connection api on databases

* update test connection tests

* update database api test and open api description

* moved test connection to commands

* update error message

* fix isort

* fix mypy

* fix black

* fix mypy pre commit
This commit is contained in:
Lily Kuang 2020-09-09 13:37:48 -07:00 committed by GitHub
parent 9a59bdda48
commit 8a3ac70c06
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 316 additions and 23 deletions

View File

@ -149,8 +149,8 @@ class SupersetAppInitializer:
AlertLogModelView,
AlertModelView,
AlertObservationModelView,
ValidatorInlineView,
SQLObserverInlineView,
ValidatorInlineView,
)
from superset.views.annotations import (
AnnotationLayerModelView,

View File

@ -20,8 +20,15 @@ from typing import Any, Optional
from flask import g, request, Response
from flask_appbuilder.api import expose, protect, rison, safe
from flask_appbuilder.models.sqla.interface import SQLAInterface
from flask_babel import gettext as _
from marshmallow import ValidationError
from sqlalchemy.exc import NoSuchTableError, OperationalError, SQLAlchemyError
from sqlalchemy.engine.url import make_url
from sqlalchemy.exc import (
NoSuchModuleError,
NoSuchTableError,
OperationalError,
SQLAlchemyError,
)
from superset import event_logger
from superset.constants import RouteMethod
@ -33,8 +40,10 @@ from superset.databases.commands.exceptions import (
DatabaseDeleteFailedError,
DatabaseInvalidError,
DatabaseNotFoundError,
DatabaseSecurityUnsafeError,
DatabaseUpdateFailedError,
)
from superset.databases.commands.test_connection import TestConnectionDatabaseCommand
from superset.databases.commands.update import UpdateDatabaseCommand
from superset.databases.dao import DatabaseDAO
from superset.databases.decorators import check_datasource_access
@ -44,6 +53,7 @@ from superset.databases.schemas import (
DatabasePostSchema,
DatabasePutSchema,
DatabaseRelatedObjectsResponse,
DatabaseTestConnectionSchema,
SchemasResponseSchema,
SelectStarResponseSchema,
TableMetadataResponseSchema,
@ -65,6 +75,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
"table_metadata",
"select_star",
"schemas",
"test_connection",
"related_objects",
}
class_permission_name = "DatabaseView"
@ -343,7 +354,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
@rison(database_schemas_query_schema)
@statsd_metrics
def schemas(self, pk: int, **kwargs: Any) -> FlaskResponse:
""" Get all schemas from a database
"""Get all schemas from a database
---
get:
description: Get all schemas from a database
@ -400,7 +411,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
def table_metadata(
self, database: Database, table_name: str, schema_name: str
) -> FlaskResponse:
""" Table schema info
"""Table schema info
---
get:
description: Get database table metadata
@ -457,7 +468,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
def select_star(
self, database: Database, table_name: str, schema_name: Optional[str] = None
) -> FlaskResponse:
""" Table schema info
"""Table schema info
---
get:
description: Get database select star for table
@ -506,6 +517,86 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
self.incr_stats("success", self.select_star.__name__)
return self.response(200, result=result)
@expose("/test_connection", methods=["POST"])
@protect()
@safe
@event_logger.log_this
@statsd_metrics
def test_connection( # pylint: disable=too-many-return-statements
self,
) -> FlaskResponse:
"""Tests a database connection
---
post:
description: >-
Tests a database connection
requestBody:
description: Database schema
required: true
content:
application/json:
schema:
type: object
properties:
encrypted_extra:
type: object
extras:
type: object
name:
type: string
server_cert:
type: string
responses:
200:
description: Database Test Connection
content:
application/json:
schema:
type: object
properties:
message:
type: string
400:
$ref: '#/components/responses/400'
422:
$ref: '#/components/responses/422'
500:
$ref: '#/components/responses/500'
"""
if not request.is_json:
return self.response_400(message="Request is not JSON")
try:
item = DatabaseTestConnectionSchema().load(request.json)
# This validates custom Schema with custom validations
except ValidationError as error:
return self.response_400(message=error.messages)
try:
TestConnectionDatabaseCommand(g.user, item).run()
return self.response(200, message="OK")
except (NoSuchModuleError, ModuleNotFoundError):
logger.info("Invalid driver")
driver_name = make_url(item.get("sqlalchemy_uri")).drivername
return self.response(
400,
message=_(f"Could not load database driver: {driver_name}"),
driver_name=driver_name,
)
except DatabaseSecurityUnsafeError as ex:
return self.response_422(message=ex)
except OperationalError:
logger.warning("Connection failed")
return self.response(
500,
message=_("Connection failed, please check your connection settings"),
)
except Exception as ex: # pylint: disable=broad-except
logger.error("Unexpected error %s", type(ex).__name__)
return self.response_400(
message=_(
"Unexpected error occurred, please check your logs for details"
)
)
@expose("/<int:pk>/related_objects/", methods=["GET"])
@protect()
@safe

View File

@ -24,6 +24,7 @@ from superset.commands.exceptions import (
DeleteFailedError,
UpdateFailedError,
)
from superset.security.analytics_db_safety import DBSecurityException
class DatabaseInvalidError(CommandInvalidError):
@ -109,3 +110,7 @@ class DatabaseDeleteDatasetsExistFailedError(DeleteFailedError):
class DatabaseDeleteFailedError(DeleteFailedError):
message = _("Database could not be deleted.")
class DatabaseSecurityUnsafeError(DBSecurityException):
message = _("Stopped an unsafe database connection")

View File

@ -0,0 +1,67 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from contextlib import closing
from typing import Any, Dict, Optional
import simplejson as json
from flask_appbuilder.security.sqla.models import User
from sqlalchemy import select
from superset.commands.base import BaseCommand
from superset.databases.commands.exceptions import DatabaseSecurityUnsafeError
from superset.databases.dao import DatabaseDAO
from superset.models.core import Database
from superset.security.analytics_db_safety import DBSecurityException
logger = logging.getLogger(__name__)
class TestConnectionDatabaseCommand(BaseCommand):
def __init__(self, user: User, data: Dict[str, Any]):
self._actor = user
self._properties = data.copy()
self._model: Optional[Database] = None
def run(self) -> None:
self.validate()
try:
uri = self._properties.get("sqlalchemy_uri", "")
if self._model and uri == self._model.safe_sqlalchemy_uri():
uri = self._model.sqlalchemy_uri_decrypted
database = DatabaseDAO.build_db_for_connection_test(
server_cert=self._properties.get("server_cert", ""),
extra=json.dumps(self._properties.get("extra", {})),
impersonate_user=self._properties.get("impersonate_user", False),
encrypted_extra=json.dumps(self._properties.get("encrypted_extra", {})),
)
if database is not None:
database.set_sqlalchemy_uri(uri)
database.db_engine_spec.mutate_db_for_connection_test(database)
username = self._actor.username if self._actor is not None else None
engine = database.get_sqla_engine(user_name=username)
with closing(engine.connect()) as conn:
conn.scalar(select([1]))
except DBSecurityException as ex:
logger.warning(ex)
raise DatabaseSecurityUnsafeError()
def validate(self) -> None:
database_name = self._properties.get("database_name")
if database_name is not None:
self._model = DatabaseDAO.get_database_by_name(database_name)

View File

@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any, Dict
from typing import Any, Dict, Optional
from superset.dao.base import BaseDAO
from superset.databases.filters import DatabaseFilter
@ -45,6 +45,25 @@ class DatabaseDAO(BaseDAO):
)
return not db.session.query(database_query.exists()).scalar()
@staticmethod
def get_database_by_name(database_name: str) -> Optional[Database]:
return (
db.session.query(Database)
.filter(Database.database_name == database_name)
.one_or_none()
)
@staticmethod
def build_db_for_connection_test(
server_cert: str, extra: str, impersonate_user: bool, encrypted_extra: str
) -> Optional[Database]:
return Database(
server_cert=server_cert,
extra=extra,
impersonate_user=impersonate_user,
encrypted_extra=encrypted_extra,
)
@classmethod
def get_related_objects(cls, database_id: int) -> Dict[str, Any]:
datasets = cls.find_by_id(database_id).tables

View File

@ -17,6 +17,7 @@
import inspect
import json
from flask import current_app
from flask_babel import lazy_gettext as _
from marshmallow import fields, Schema
from marshmallow.validate import Length, ValidationError
@ -24,7 +25,6 @@ from sqlalchemy import MetaData
from sqlalchemy.engine.url import make_url
from sqlalchemy.exc import ArgumentError
from superset import app
from superset.exceptions import CertificateException
from superset.utils.core import markdown, parse_ssl_cert
@ -142,7 +142,7 @@ def sqlalchemy_uri_validator(value: str) -> str:
)
]
)
if app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] and value:
if current_app.config.get("PREVENT_UNSAFE_DB_CONNECTIONS", True) and value:
if value.startswith("sqlite"):
raise ValidationError(
[
@ -291,6 +291,25 @@ class DatabasePutSchema(Schema):
)
class DatabaseTestConnectionSchema(Schema):
database_name = fields.String(
description=database_name_description, allow_none=True, validate=Length(1, 250),
)
impersonate_user = fields.Boolean(description=impersonate_user_description)
extra = fields.String(description=extra_description, validate=extra_validator)
encrypted_extra = fields.String(
description=encrypted_extra_description, validate=encrypted_extra_validator
)
server_cert = fields.String(
description=server_cert_description, validate=server_cert_validator
)
sqlalchemy_uri = fields.String(
description=sqlalchemy_uri_description,
required=True,
validate=[Length(1, 1024), sqlalchemy_uri_validator],
)
class TableMetadataOptionsResponseSchema(Schema):
deferrable = fields.Bool()
initially = fields.Bool()

View File

@ -70,8 +70,8 @@ from superset.utils.urls import get_url_path
if TYPE_CHECKING:
# pylint: disable=unused-import
from werkzeug.datastructures import TypeConversionDict
from flask_appbuilder.security.sqla.models import User
from werkzeug.datastructures import TypeConversionDict
# Globals
config = app.config

View File

@ -91,24 +91,25 @@ class BaseSupersetModelRestApi(ModelRestApi):
csrf_exempt = False
method_permission_name = {
"get_list": "list",
"get": "show",
"bulk_delete": "delete",
"data": "list",
"delete": "delete",
"distinct": "list",
"export": "mulexport",
"get": "show",
"get_list": "list",
"info": "list",
"post": "add",
"put": "edit",
"delete": "delete",
"bulk_delete": "delete",
"info": "list",
"related": "list",
"distinct": "list",
"thumbnail": "list",
"refresh": "edit",
"data": "list",
"viz_types": "list",
"related": "list",
"related_objects": "list",
"table_metadata": "list",
"select_star": "list",
"schemas": "list",
"select_star": "list",
"table_metadata": "list",
"test_connection": "post",
"thumbnail": "list",
"viz_types": "list",
}
order_rel_fields: Dict[str, Tuple[str, str]] = {}

View File

@ -1162,7 +1162,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
logger.warning("Stopped an unsafe database connection")
return json_error_response(_(str(ex)), 400)
except Exception as ex: # pylint: disable=broad-except
logger.error("Unexpected error %s", type(ex).__name__)
logger.warning("Unexpected error %s", type(ex).__name__)
return json_error_response(
_("Unexpected error occurred, please check your logs for details"), 400
)

View File

@ -21,13 +21,13 @@ import json
import prison
from sqlalchemy.sql import func
import tests.test_app
from superset import db, security_manager
from superset.connectors.sqla.models import SqlaTable
from superset.models.core import Database
from superset.utils.core import get_example_database, get_main_database
from tests.base_tests import SupersetTestCase
from tests.fixtures.certificates import ssl_certificate
from tests.test_app import app
class TestDatabaseApi(SupersetTestCase):
@ -652,6 +652,97 @@ class TestDatabaseApi(SupersetTestCase):
)
self.assertEqual(rv.status_code, 400)
def test_test_connection(self):
"""
Database API: Test test connection
"""
# need to temporarily allow sqlite dbs, teardown will undo this
app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] = False
self.login("admin")
example_db = get_example_database()
# validate that the endpoint works with the password-masked sqlalchemy uri
data = {
"sqlalchemy_uri": example_db.safe_sqlalchemy_uri(),
"database_name": "examples",
"impersonate_user": False,
}
url = f"api/v1/database/test_connection"
rv = self.post_assert_metric(url, data, "test_connection")
self.assertEqual(rv.status_code, 200)
self.assertEqual(rv.headers["Content-Type"], "application/json; charset=utf-8")
# validate that the endpoint works with the decrypted sqlalchemy uri
data = {
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
"database_name": "examples",
"impersonate_user": False,
}
rv = self.post_assert_metric(url, data, "test_connection")
self.assertEqual(rv.status_code, 200)
self.assertEqual(rv.headers["Content-Type"], "application/json; charset=utf-8")
def test_test_connection_failed(self):
"""
Database API: Test test connection failed
"""
self.login("admin")
data = {
"sqlalchemy_uri": "broken://url",
"database_name": "examples",
"impersonate_user": False,
}
url = f"api/v1/database/test_connection"
rv = self.post_assert_metric(url, data, "test_connection")
self.assertEqual(rv.status_code, 400)
self.assertEqual(rv.headers["Content-Type"], "application/json; charset=utf-8")
response = json.loads(rv.data.decode("utf-8"))
expected_response = {
"driver_name": "broken",
"message": "Could not load database driver: broken",
}
self.assertEqual(response, expected_response)
data = {
"sqlalchemy_uri": "mssql+pymssql://url",
"database_name": "examples",
"impersonate_user": False,
}
rv = self.post_assert_metric(url, data, "test_connection")
self.assertEqual(rv.status_code, 400)
self.assertEqual(rv.headers["Content-Type"], "application/json; charset=utf-8")
response = json.loads(rv.data.decode("utf-8"))
expected_response = {
"driver_name": "mssql+pymssql",
"message": "Could not load database driver: mssql+pymssql",
}
self.assertEqual(response, expected_response)
def test_test_connection_unsafe_uri(self):
"""
Database API: Test test connection with unsafe uri
"""
self.login("admin")
app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] = True
data = {
"sqlalchemy_uri": "sqlite:///home/superset/unsafe.db",
"database_name": "unsafe",
"impersonate_user": False,
}
url = f"api/v1/database/test_connection"
rv = self.post_assert_metric(url, data, "test_connection")
self.assertEqual(rv.status_code, 400)
response = json.loads(rv.data.decode("utf-8"))
expected_response = {
"message": {
"sqlalchemy_uri": [
"SQLite database cannot be used as a data source for security reasons."
]
}
}
self.assertEqual(response, expected_response)
def test_get_database_related_objects(self):
"""
Database API: Test get chart and dashboard count related to a database