diff --git a/superset/config.py b/superset/config.py index 79b3f2de9..42a46012e 100644 --- a/superset/config.py +++ b/superset/config.py @@ -33,10 +33,11 @@ import os import re import sys from collections import OrderedDict +from contextlib import contextmanager from datetime import timedelta from email.mime.multipart import MIMEMultipart from importlib.resources import files -from typing import Any, Callable, Literal, TYPE_CHECKING, TypedDict +from typing import Any, Callable, Iterator, Literal, TYPE_CHECKING, TypedDict import click import pkg_resources @@ -1146,16 +1147,18 @@ def CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC( # pylint: disable=invalid-name # uploading CSVs will be stored. UPLOADED_CSV_HIVE_NAMESPACE: str | None = None + # Function that computes the allowed schemas for the CSV uploads. # Allowed schemas will be a union of schemas_allowed_for_file_upload # db configuration and a result of this function. +def allowed_schemas_for_csv_upload( # pylint: disable=unused-argument + database: Database, + user: models.User, +) -> list[str]: + return [UPLOADED_CSV_HIVE_NAMESPACE] if UPLOADED_CSV_HIVE_NAMESPACE else [] -# mypy doesn't catch that if case ensures list content being always str -ALLOWED_USER_CSV_SCHEMA_FUNC: Callable[[Database, models.User], list[str]] = ( # noqa: E731 - lambda database, user: [UPLOADED_CSV_HIVE_NAMESPACE] - if UPLOADED_CSV_HIVE_NAMESPACE - else [] -) + +ALLOWED_USER_CSV_SCHEMA_FUNC = allowed_schemas_for_csv_upload # Values that should be treated as nulls for the csv uploads. CSV_DEFAULT_NA_NAMES = list(STR_NA_VALUES) @@ -1266,6 +1269,21 @@ ALLOWED_EXTRA_AUTHENTICATIONS: dict[str, dict[str, Callable[..., Any]]] = {} # The id of a template dashboard that should be copied to every new user DASHBOARD_TEMPLATE_ID = None + +# A context manager that wraps the call to `create_engine`. 
This can be used for many +# things, such as chrooting to prevent 3rd party drivers from accessing the filesystem, or +# setting up custom configuration for database drivers. +@contextmanager +def engine_context_manager( # pylint: disable=unused-argument + database: Database, + catalog: str | None, + schema: str | None, +) -> Iterator[None]: + yield None + + +ENGINE_CONTEXT_MANAGER = engine_context_manager + # A callable that allows altering the database connection URL and params # on the fly, at runtime. This allows for things like impersonation or # arbitrary logic. For instance you can wire different users to diff --git a/superset/models/core.py b/superset/models/core.py index 9d432f811..5d3a6ea74 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -418,38 +418,40 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable ) sqlalchemy_uri = self.sqlalchemy_uri_decrypted - engine_context = nullcontext() - ssh_tunnel = override_ssh_tunnel or DatabaseDAO.get_ssh_tunnel( - database_id=self.id - ) - if ssh_tunnel: - # if ssh_tunnel is available build engine with information - engine_context = ssh_manager_factory.instance.create_tunnel( + ssh_tunnel = override_ssh_tunnel or DatabaseDAO.get_ssh_tunnel(self.id) + ssh_context_manager = ( + ssh_manager_factory.instance.create_tunnel( ssh_tunnel=ssh_tunnel, sqlalchemy_database_uri=sqlalchemy_uri, ) + if ssh_tunnel + else nullcontext() + ) - with engine_context as server_context: - if ssh_tunnel and server_context: + with ssh_context_manager as ssh_context: + if ssh_context: logger.info( - "[SSH] Successfully created tunnel w/ %s tunnel_timeout + %s ssh_timeout at %s", + "[SSH] Successfully created tunnel w/ %s tunnel_timeout + %s " + "ssh_timeout at %s", sshtunnel.TUNNEL_TIMEOUT, sshtunnel.SSH_TIMEOUT, - server_context.local_bind_address, + ssh_context.local_bind_address, ) sqlalchemy_uri = ssh_manager_factory.instance.build_sqla_url( sqlalchemy_uri, - server_context, + ssh_context, ) - 
yield self._get_sqla_engine( - catalog=catalog, - schema=schema, - nullpool=nullpool, - source=source, - sqlalchemy_uri=sqlalchemy_uri, - ) + engine_context_manager = config["ENGINE_CONTEXT_MANAGER"] + with engine_context_manager(self, catalog, schema): + yield self._get_sqla_engine( + catalog=catalog, + schema=schema, + nullpool=nullpool, + source=source, + sqlalchemy_uri=sqlalchemy_uri, + ) def _get_sqla_engine( # pylint: disable=too-many-locals self, diff --git a/tests/unit_tests/models/core_test.py b/tests/unit_tests/models/core_test.py index 0346020c5..3c591d446 100644 --- a/tests/unit_tests/models/core_test.py +++ b/tests/unit_tests/models/core_test.py @@ -233,10 +233,7 @@ def test_get_prequeries(mocker: MockerFixture) -> None: """ Tests for ``get_prequeries``. """ - mocker.patch.object( - Database, - "get_sqla_engine", - ) + mocker.patch.object(Database, "get_sqla_engine") db_engine_spec = mocker.patch.object(Database, "db_engine_spec") db_engine_spec.get_prequeries.return_value = ["set a=1", "set b=2"] @@ -397,10 +394,7 @@ def test_get_sqla_engine(mocker: MockerFixture) -> None: create_engine = mocker.patch("superset.models.core.create_engine") - database = Database( - database_name="my_db", - sqlalchemy_uri="trino://", - ) + database = Database(database_name="my_db", sqlalchemy_uri="trino://") database._get_sqla_engine(nullpool=False) create_engine.assert_called_with( @@ -556,3 +550,30 @@ def test_get_schema_access_for_file_upload() -> None: ) assert database.get_schema_access_for_file_upload() == {"public"} + + +def test_engine_context_manager(mocker: MockerFixture) -> None: + """ + Test the engine context manager. 
+ """ + engine_context_manager = mocker.MagicMock() + mocker.patch( + "superset.models.core.config", + new={"ENGINE_CONTEXT_MANAGER": engine_context_manager}, + ) + _get_sqla_engine = mocker.patch.object(Database, "_get_sqla_engine") + + database = Database(database_name="my_db", sqlalchemy_uri="trino://") + with database.get_sqla_engine("catalog", "schema"): + pass + + engine_context_manager.assert_called_once_with(database, "catalog", "schema") + engine_context_manager().__enter__.assert_called_once() + engine_context_manager().__exit__.assert_called_once_with(None, None, None) + _get_sqla_engine.assert_called_once_with( + catalog="catalog", + schema="schema", + nullpool=True, + source=None, + sqlalchemy_uri="trino://", + )