diff --git a/RESOURCES/FEATURE_FLAGS.md b/RESOURCES/FEATURE_FLAGS.md index a8796084f..950e652ef 100644 --- a/RESOURCES/FEATURE_FLAGS.md +++ b/RESOURCES/FEATURE_FLAGS.md @@ -50,6 +50,7 @@ These features are **finished** but currently being tested. They are usable, but - ESTIMATE_QUERY_COST - GLOBAL_ASYNC_QUERIES [(docs)](https://github.com/apache/superset/blob/master/CONTRIBUTING.md#async-chart-queries) - HORIZONTAL_FILTER_BAR +- IMPERSONATE_WITH_EMAIL_PREFIX - PLAYWRIGHT_REPORTS_AND_THUMBNAILS - RLS_IN_SQLLAB - SSH_TUNNELING [(docs)](https://superset.apache.org/docs/configuration/setup-ssh-tunneling) diff --git a/superset/config.py b/superset/config.py index aa8178d08..cb7798299 100644 --- a/superset/config.py +++ b/superset/config.py @@ -461,6 +461,8 @@ DEFAULT_FEATURE_FLAGS: dict[str, bool] = { # Apply RLS rules to SQL Lab queries. This requires parsing and manipulating the # query, and might break queries and/or allow users to bypass RLS. Use with care! "RLS_IN_SQLLAB": False, + # When impersonating a user, use the email prefix instead of the username + "IMPERSONATE_WITH_EMAIL_PREFIX": False, # Enable caching per impersonation key (e.g username) in a datasource where user # impersonation is enabled "CACHE_IMPERSONATION": False, diff --git a/superset/models/core.py b/superset/models/core.py index b933c1694..e6d97a197 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -59,7 +59,7 @@ from sqlalchemy.pool import NullPool from sqlalchemy.schema import UniqueConstraint from sqlalchemy.sql import ColumnElement, expression, Select -from superset import app, db_engine_specs +from superset import app, db_engine_specs, is_feature_enabled from superset.commands.database.exceptions import DatabaseInvalidError from superset.constants import LRU_CACHE_MAX_SIZE, PASSWORD_MASK from superset.databases.utils import make_url_safe @@ -450,7 +450,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable sqlalchemy_uri=sqlalchemy_uri, ) - def _get_sqla_engine( + def _get_sqla_engine( # pylint: disable=too-many-locals self, catalog: str | None = None, schema: str | None = None, @@ -477,6 +477,11 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable ) effective_username = self.get_effective_user(sqlalchemy_url) + if effective_username and is_feature_enabled("IMPERSONATE_WITH_EMAIL_PREFIX"): + user = security_manager.find_user(username=effective_username) + if user and user.email: + effective_username = user.email.split("@")[0] + oauth2_config = self.get_oauth2_config() access_token = ( get_oauth2_access_token( diff --git a/tests/unit_tests/models/core_test.py b/tests/unit_tests/models/core_test.py index e653eee71..2004ff482 100644 --- a/tests/unit_tests/models/core_test.py +++ b/tests/unit_tests/models/core_test.py @@ -21,11 +21,13 @@ from datetime import datetime import pytest from pytest_mock import MockFixture from sqlalchemy.engine.reflection import Inspector +from sqlalchemy.engine.url import make_url from superset.connectors.sqla.models import SqlaTable, TableColumn from superset.models.core import Database from superset.sql_parse import Table from superset.utils import json +from tests.unit_tests.conftest import with_feature_flags def test_get_metrics(mocker: MockFixture) -> None: @@ -289,3 +291,90 @@ def test_get_all_catalog_names(mocker: MockFixture) -> None: assert database.get_all_catalog_names(force=True) == {"examples", "other"} get_inspector.assert_called_with(ssh_tunnel=None) + + +def test_get_sqla_engine(mocker: MockFixture) -> None: + """ + Test `_get_sqla_engine`. + """ + from superset.models.core import Database + + user = mocker.MagicMock() + user.email = "alice.doe@example.org" + mocker.patch( + "superset.models.core.security_manager.find_user", + return_value=user, + ) + mocker.patch("superset.models.core.get_username", return_value="alice") + + create_engine = mocker.patch("superset.models.core.create_engine") + + database = Database( + database_name="my_db", + sqlalchemy_uri="trino://", + ) + database._get_sqla_engine(nullpool=False) + + create_engine.assert_called_with( + make_url("trino:///"), + connect_args={"source": "Apache Superset"}, + ) + + +def test_get_sqla_engine_user_impersonation(mocker: MockFixture) -> None: + """ + Test user impersonation in `_get_sqla_engine`. + """ + from superset.models.core import Database + + user = mocker.MagicMock() + user.email = "alice.doe@example.org" + mocker.patch( + "superset.models.core.security_manager.find_user", + return_value=user, + ) + mocker.patch("superset.models.core.get_username", return_value="alice") + + create_engine = mocker.patch("superset.models.core.create_engine") + + database = Database( + database_name="my_db", + sqlalchemy_uri="trino://", + impersonate_user=True, + ) + database._get_sqla_engine(nullpool=False) + + create_engine.assert_called_with( + make_url("trino:///"), + connect_args={"user": "alice", "source": "Apache Superset"}, + ) + + +@with_feature_flags(IMPERSONATE_WITH_EMAIL_PREFIX=True) +def test_get_sqla_engine_user_impersonation_email(mocker: MockFixture) -> None: + """ + Test user impersonation in `_get_sqla_engine` with `username_from_email`. + """ + from superset.models.core import Database + + user = mocker.MagicMock() + user.email = "alice.doe@example.org" + mocker.patch( + "superset.models.core.security_manager.find_user", + return_value=user, + ) + mocker.patch("superset.models.core.get_username", return_value="alice") + + create_engine = mocker.patch("superset.models.core.create_engine") + + database = Database( + database_name="my_db", + sqlalchemy_uri="trino://", + impersonate_user=True, + ) + database._get_sqla_engine(nullpool=False) + + create_engine.assert_called_with( + make_url("trino:///"), + connect_args={"user": "alice.doe", "source": "Apache Superset"}, + )