feat: allow configuring an engine context manager (#30266)
This commit is contained in:
parent
ee3a56714e
commit
710406aa76
|
|
@ -33,10 +33,11 @@ import os
|
|||
import re
|
||||
import sys
|
||||
from collections import OrderedDict
|
||||
from contextlib import contextmanager
|
||||
from datetime import timedelta
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from importlib.resources import files
|
||||
from typing import Any, Callable, Literal, TYPE_CHECKING, TypedDict
|
||||
from typing import Any, Callable, Iterator, Literal, TYPE_CHECKING, TypedDict
|
||||
|
||||
import click
|
||||
import pkg_resources
|
||||
|
|
@ -1146,16 +1147,18 @@ def CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC( # pylint: disable=invalid-name
|
|||
# uploading CSVs will be stored.
|
||||
UPLOADED_CSV_HIVE_NAMESPACE: str | None = None
|
||||
|
||||
|
||||
# Function that computes the allowed schemas for the CSV uploads.
|
||||
# Allowed schemas will be a union of schemas_allowed_for_file_upload
|
||||
# db configuration and a result of this function.
|
||||
def allowed_schemas_for_csv_upload( # pylint: disable=unused-argument
|
||||
database: Database,
|
||||
user: models.User,
|
||||
) -> list[str]:
|
||||
return [UPLOADED_CSV_HIVE_NAMESPACE] if UPLOADED_CSV_HIVE_NAMESPACE else []
|
||||
|
||||
# mypy doesn't catch that if case ensures list content being always str
|
||||
ALLOWED_USER_CSV_SCHEMA_FUNC: Callable[[Database, models.User], list[str]] = ( # noqa: E731
|
||||
lambda database, user: [UPLOADED_CSV_HIVE_NAMESPACE]
|
||||
if UPLOADED_CSV_HIVE_NAMESPACE
|
||||
else []
|
||||
)
|
||||
|
||||
ALLOWED_USER_CSV_SCHEMA_FUNC = allowed_schemas_for_csv_upload
|
||||
|
||||
# Values that should be treated as nulls for the csv uploads.
|
||||
CSV_DEFAULT_NA_NAMES = list(STR_NA_VALUES)
|
||||
|
|
@ -1266,6 +1269,21 @@ ALLOWED_EXTRA_AUTHENTICATIONS: dict[str, dict[str, Callable[..., Any]]] = {}
|
|||
# The id of a template dashboard that should be copied to every new user
|
||||
DASHBOARD_TEMPLATE_ID = None
|
||||
|
||||
|
||||
# A context manager that wraps the call to `create_engine`. This can be used for many
|
||||
# things, such as chrooting to prevent 3rd party drivers to access the filesystem, or
|
||||
# setting up custom configuration for database drivers.
|
||||
@contextmanager
|
||||
def engine_context_manager( # pylint: disable=unused-argument
|
||||
database: Database,
|
||||
catalog: str | None,
|
||||
schema: str | None,
|
||||
) -> Iterator[None]:
|
||||
yield None
|
||||
|
||||
|
||||
ENGINE_CONTEXT_MANAGER = engine_context_manager
|
||||
|
||||
# A callable that allows altering the database connection URL and params
|
||||
# on the fly, at runtime. This allows for things like impersonation or
|
||||
# arbitrary logic. For instance you can wire different users to
|
||||
|
|
|
|||
|
|
@ -418,38 +418,40 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
|
|||
)
|
||||
|
||||
sqlalchemy_uri = self.sqlalchemy_uri_decrypted
|
||||
engine_context = nullcontext()
|
||||
ssh_tunnel = override_ssh_tunnel or DatabaseDAO.get_ssh_tunnel(
|
||||
database_id=self.id
|
||||
)
|
||||
|
||||
if ssh_tunnel:
|
||||
# if ssh_tunnel is available build engine with information
|
||||
engine_context = ssh_manager_factory.instance.create_tunnel(
|
||||
ssh_tunnel = override_ssh_tunnel or DatabaseDAO.get_ssh_tunnel(self.id)
|
||||
ssh_context_manager = (
|
||||
ssh_manager_factory.instance.create_tunnel(
|
||||
ssh_tunnel=ssh_tunnel,
|
||||
sqlalchemy_database_uri=sqlalchemy_uri,
|
||||
)
|
||||
if ssh_tunnel
|
||||
else nullcontext()
|
||||
)
|
||||
|
||||
with engine_context as server_context:
|
||||
if ssh_tunnel and server_context:
|
||||
with ssh_context_manager as ssh_context:
|
||||
if ssh_context:
|
||||
logger.info(
|
||||
"[SSH] Successfully created tunnel w/ %s tunnel_timeout + %s ssh_timeout at %s",
|
||||
"[SSH] Successfully created tunnel w/ %s tunnel_timeout + %s "
|
||||
"ssh_timeout at %s",
|
||||
sshtunnel.TUNNEL_TIMEOUT,
|
||||
sshtunnel.SSH_TIMEOUT,
|
||||
server_context.local_bind_address,
|
||||
ssh_context.local_bind_address,
|
||||
)
|
||||
sqlalchemy_uri = ssh_manager_factory.instance.build_sqla_url(
|
||||
sqlalchemy_uri,
|
||||
server_context,
|
||||
ssh_context,
|
||||
)
|
||||
|
||||
yield self._get_sqla_engine(
|
||||
catalog=catalog,
|
||||
schema=schema,
|
||||
nullpool=nullpool,
|
||||
source=source,
|
||||
sqlalchemy_uri=sqlalchemy_uri,
|
||||
)
|
||||
engine_context_manager = config["ENGINE_CONTEXT_MANAGER"]
|
||||
with engine_context_manager(self, catalog, schema):
|
||||
yield self._get_sqla_engine(
|
||||
catalog=catalog,
|
||||
schema=schema,
|
||||
nullpool=nullpool,
|
||||
source=source,
|
||||
sqlalchemy_uri=sqlalchemy_uri,
|
||||
)
|
||||
|
||||
def _get_sqla_engine( # pylint: disable=too-many-locals
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -233,10 +233,7 @@ def test_get_prequeries(mocker: MockerFixture) -> None:
|
|||
"""
|
||||
Tests for ``get_prequeries``.
|
||||
"""
|
||||
mocker.patch.object(
|
||||
Database,
|
||||
"get_sqla_engine",
|
||||
)
|
||||
mocker.patch.object(Database, "get_sqla_engine")
|
||||
db_engine_spec = mocker.patch.object(Database, "db_engine_spec")
|
||||
db_engine_spec.get_prequeries.return_value = ["set a=1", "set b=2"]
|
||||
|
||||
|
|
@ -397,10 +394,7 @@ def test_get_sqla_engine(mocker: MockerFixture) -> None:
|
|||
|
||||
create_engine = mocker.patch("superset.models.core.create_engine")
|
||||
|
||||
database = Database(
|
||||
database_name="my_db",
|
||||
sqlalchemy_uri="trino://",
|
||||
)
|
||||
database = Database(database_name="my_db", sqlalchemy_uri="trino://")
|
||||
database._get_sqla_engine(nullpool=False)
|
||||
|
||||
create_engine.assert_called_with(
|
||||
|
|
@ -556,3 +550,30 @@ def test_get_schema_access_for_file_upload() -> None:
|
|||
)
|
||||
|
||||
assert database.get_schema_access_for_file_upload() == {"public"}
|
||||
|
||||
|
||||
def test_engine_context_manager(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test the engine context manager.
|
||||
"""
|
||||
engine_context_manager = mocker.MagicMock()
|
||||
mocker.patch(
|
||||
"superset.models.core.config",
|
||||
new={"ENGINE_CONTEXT_MANAGER": engine_context_manager},
|
||||
)
|
||||
_get_sqla_engine = mocker.patch.object(Database, "_get_sqla_engine")
|
||||
|
||||
database = Database(database_name="my_db", sqlalchemy_uri="trino://")
|
||||
with database.get_sqla_engine("catalog", "schema"):
|
||||
pass
|
||||
|
||||
engine_context_manager.assert_called_once_with(database, "catalog", "schema")
|
||||
engine_context_manager().__enter__.assert_called_once()
|
||||
engine_context_manager().__exit__.assert_called_once_with(None, None, None)
|
||||
_get_sqla_engine.assert_called_once_with(
|
||||
catalog="catalog",
|
||||
schema="schema",
|
||||
nullpool=True,
|
||||
source=None,
|
||||
sqlalchemy_uri="trino://",
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue