feat: allow configuring an engine context manager (#30266)

This commit is contained in:
Beto Dealmeida 2024-09-23 12:36:18 -04:00 committed by GitHub
parent ee3a56714e
commit 710406aa76
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 75 additions and 34 deletions

View File

@ -33,10 +33,11 @@ import os
import re
import sys
from collections import OrderedDict
from contextlib import contextmanager
from datetime import timedelta
from email.mime.multipart import MIMEMultipart
from importlib.resources import files
from typing import Any, Callable, Literal, TYPE_CHECKING, TypedDict
from typing import Any, Callable, Iterator, Literal, TYPE_CHECKING, TypedDict
import click
import pkg_resources
@ -1146,16 +1147,18 @@ def CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC( # pylint: disable=invalid-name
# uploading CSVs will be stored.
UPLOADED_CSV_HIVE_NAMESPACE: str | None = None
# Function that computes the allowed schemas for the CSV uploads.
# Allowed schemas will be a union of schemas_allowed_for_file_upload
# db configuration and a result of this function.
def allowed_schemas_for_csv_upload( # pylint: disable=unused-argument
database: Database,
user: models.User,
) -> list[str]:
return [UPLOADED_CSV_HIVE_NAMESPACE] if UPLOADED_CSV_HIVE_NAMESPACE else []
# mypy doesn't catch that if case ensures list content being always str
ALLOWED_USER_CSV_SCHEMA_FUNC: Callable[[Database, models.User], list[str]] = ( # noqa: E731
lambda database, user: [UPLOADED_CSV_HIVE_NAMESPACE]
if UPLOADED_CSV_HIVE_NAMESPACE
else []
)
ALLOWED_USER_CSV_SCHEMA_FUNC = allowed_schemas_for_csv_upload
# Values that should be treated as nulls for the csv uploads.
CSV_DEFAULT_NA_NAMES = list(STR_NA_VALUES)
@ -1266,6 +1269,21 @@ ALLOWED_EXTRA_AUTHENTICATIONS: dict[str, dict[str, Callable[..., Any]]] = {}
# The id of a template dashboard that should be copied to every new user
DASHBOARD_TEMPLATE_ID = None
# A context manager that wraps the call to `create_engine`. This can be used for many
# things, such as chrooting to prevent 3rd party drivers to access the filesystem, or
# setting up custom configuration for database drivers.
@contextmanager
def engine_context_manager( # pylint: disable=unused-argument
database: Database,
catalog: str | None,
schema: str | None,
) -> Iterator[None]:
yield None
ENGINE_CONTEXT_MANAGER = engine_context_manager
# A callable that allows altering the database connection URL and params
# on the fly, at runtime. This allows for things like impersonation or
# arbitrary logic. For instance you can wire different users to

View File

@ -418,38 +418,40 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
)
sqlalchemy_uri = self.sqlalchemy_uri_decrypted
engine_context = nullcontext()
ssh_tunnel = override_ssh_tunnel or DatabaseDAO.get_ssh_tunnel(
database_id=self.id
)
if ssh_tunnel:
# if ssh_tunnel is available build engine with information
engine_context = ssh_manager_factory.instance.create_tunnel(
ssh_tunnel = override_ssh_tunnel or DatabaseDAO.get_ssh_tunnel(self.id)
ssh_context_manager = (
ssh_manager_factory.instance.create_tunnel(
ssh_tunnel=ssh_tunnel,
sqlalchemy_database_uri=sqlalchemy_uri,
)
if ssh_tunnel
else nullcontext()
)
with engine_context as server_context:
if ssh_tunnel and server_context:
with ssh_context_manager as ssh_context:
if ssh_context:
logger.info(
"[SSH] Successfully created tunnel w/ %s tunnel_timeout + %s ssh_timeout at %s",
"[SSH] Successfully created tunnel w/ %s tunnel_timeout + %s "
"ssh_timeout at %s",
sshtunnel.TUNNEL_TIMEOUT,
sshtunnel.SSH_TIMEOUT,
server_context.local_bind_address,
ssh_context.local_bind_address,
)
sqlalchemy_uri = ssh_manager_factory.instance.build_sqla_url(
sqlalchemy_uri,
server_context,
ssh_context,
)
yield self._get_sqla_engine(
catalog=catalog,
schema=schema,
nullpool=nullpool,
source=source,
sqlalchemy_uri=sqlalchemy_uri,
)
engine_context_manager = config["ENGINE_CONTEXT_MANAGER"]
with engine_context_manager(self, catalog, schema):
yield self._get_sqla_engine(
catalog=catalog,
schema=schema,
nullpool=nullpool,
source=source,
sqlalchemy_uri=sqlalchemy_uri,
)
def _get_sqla_engine( # pylint: disable=too-many-locals
self,

View File

@ -233,10 +233,7 @@ def test_get_prequeries(mocker: MockerFixture) -> None:
"""
Tests for ``get_prequeries``.
"""
mocker.patch.object(
Database,
"get_sqla_engine",
)
mocker.patch.object(Database, "get_sqla_engine")
db_engine_spec = mocker.patch.object(Database, "db_engine_spec")
db_engine_spec.get_prequeries.return_value = ["set a=1", "set b=2"]
@ -397,10 +394,7 @@ def test_get_sqla_engine(mocker: MockerFixture) -> None:
create_engine = mocker.patch("superset.models.core.create_engine")
database = Database(
database_name="my_db",
sqlalchemy_uri="trino://",
)
database = Database(database_name="my_db", sqlalchemy_uri="trino://")
database._get_sqla_engine(nullpool=False)
create_engine.assert_called_with(
@ -556,3 +550,30 @@ def test_get_schema_access_for_file_upload() -> None:
)
assert database.get_schema_access_for_file_upload() == {"public"}
def test_engine_context_manager(mocker: MockerFixture) -> None:
"""
Test the engine context manager.
"""
engine_context_manager = mocker.MagicMock()
mocker.patch(
"superset.models.core.config",
new={"ENGINE_CONTEXT_MANAGER": engine_context_manager},
)
_get_sqla_engine = mocker.patch.object(Database, "_get_sqla_engine")
database = Database(database_name="my_db", sqlalchemy_uri="trino://")
with database.get_sqla_engine("catalog", "schema"):
pass
engine_context_manager.assert_called_once_with(database, "catalog", "schema")
engine_context_manager().__enter__.assert_called_once()
engine_context_manager().__exit__.assert_called_once_with(None, None, None)
_get_sqla_engine.assert_called_once_with(
catalog="catalog",
schema="schema",
nullpool=True,
source=None,
sqlalchemy_uri="trino://",
)