From 6294e339e2f3398d93ed4e3da4ea82aefd7945d5 Mon Sep 17 00:00:00 2001 From: Patrick Schmidt Date: Fri, 6 Sep 2024 18:13:38 +0200 Subject: [PATCH] feat(db_engine): Implement user impersonation support for StarRocks (#28110) --- superset/db_engine_specs/README.md | 2 +- superset/db_engine_specs/base.py | 1 + superset/db_engine_specs/databricks.py | 1 + superset/db_engine_specs/db2.py | 2 + superset/db_engine_specs/postgres.py | 1 + superset/db_engine_specs/starrocks.py | 50 ++++++++++++++++++- superset/models/core.py | 1 + .../db_engine_specs/test_databricks.py | 14 ++++-- tests/unit_tests/db_engine_specs/test_db2.py | 8 +-- .../db_engine_specs/test_postgres.py | 8 +-- .../db_engine_specs/test_starrocks.py | 45 +++++++++++++++++ 11 files changed, 120 insertions(+), 13 deletions(-) diff --git a/superset/db_engine_specs/README.md b/superset/db_engine_specs/README.md index 9cdd520d5..75ab73515 100644 --- a/superset/db_engine_specs/README.md +++ b/superset/db_engine_specs/README.md @@ -95,7 +95,7 @@ The table below (generated via `python superset/db_engine_specs/lib.py`) summari | Masks/unmasks encrypted_extra | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | True | True | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | | Has column type mappings | False | False | False | False | False | True | False | False | False | False | True | False | True | True | True | True | True | True | False | False | True | False | False | False | False | False | False | False | False | False | False | False | False | False | True | True | True | False | False | False | True | True | False | False | False | False | False | True | False | True | False | True | | Returns a list of function names | False | False | False | False | False | True | False | False | False | False | True | False | False | False | False | True | True | False | False | False | True | False | False | False | False | False | False | False | False | False | True | False | False | False | False | False | False | False | False | False | True | False | False | False | True | True | False | False | False | True | False | True | -| Supports user impersonation | False | False | False | True | False | True | False | False | False | False | True | False | False | False | False | False | False | False | False | False | True | False | False | False | False | False | False | False | False | False | True | False | False | False | False | False | False | False | False | False | True | False | False | False | False | False | False | False | False | True | False | False | +| Supports user impersonation | False | False | False | True | False | True | False | False | False | False | True | False | False | False | False | False | False | False | False | False | True | False | False | False | False | False | False | False | False | False | True | False | False | False | False | False | False | False | False | False | True | False | False | False | False | False | False | True | False | True | False | False | | Support file upload | True | True | True | True | True | True | True | True | True | True | True | True | True | True | True | False | False | True | True | True | True | True | True | True | True | True | True | True | True | True | False | True | True | True | True | True | True | True | True | True | True | True | True | True | True | True | True | True | True | True | True | True | | Returns extra table metadata | False | False | False | False | False | True | False | False | False | False | True | False | False | False | False | False | False | False | False | False | True | False | False | False | False | False | False | False | False | True | True | False | False | False | False | False | False | False | False | False | True | False | False | False | False | False | False | False | False | True | False | False | | Maps driver exceptions to Superset exceptions | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 678649e85..2b32d156d 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -1408,6 +1408,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods @classmethod def get_prequeries( cls, + database: Database, # pylint: disable=unused-argument catalog: str | None = None, # pylint: disable=unused-argument schema: str | None = None, # pylint: disable=unused-argument ) -> list[str]: diff --git a/superset/db_engine_specs/databricks.py b/superset/db_engine_specs/databricks.py index 4f66e2fdc..88d6407a3 100644 --- a/superset/db_engine_specs/databricks.py +++ b/superset/db_engine_specs/databricks.py @@ -458,6 +458,7 @@ class DatabricksNativeEngineSpec(DatabricksDynamicBaseEngineSpec): @classmethod def get_prequeries( cls, + database: Database, catalog: str | None = None, schema: str | None = None, ) -> list[str]: diff --git a/superset/db_engine_specs/db2.py b/superset/db_engine_specs/db2.py index b2151767d..6781701ac 100644 --- a/superset/db_engine_specs/db2.py +++ b/superset/db_engine_specs/db2.py @@ -21,6 +21,7 @@ from sqlalchemy.engine.reflection import Inspector from superset.constants import TimeGrain from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod +from superset.models.core import Database from superset.sql_parse import Table logger = logging.getLogger(__name__) @@ -93,6 +94,7 @@ class Db2EngineSpec(BaseEngineSpec): @classmethod def get_prequeries( cls, + database: Database, catalog: Union[str, None] = None, schema: Union[str, None] = None, ) -> list[str]: diff --git a/superset/db_engine_specs/postgres.py b/superset/db_engine_specs/postgres.py index 015d5c52f..8525ea05d 100644 --- a/superset/db_engine_specs/postgres.py +++ b/superset/db_engine_specs/postgres.py @@ -322,6 +322,7 @@ class PostgresEngineSpec(BasicParametersMixin, PostgresBaseEngineSpec): @classmethod def get_prequeries( cls, + database: Database, catalog: str | None = None, schema: str | None = None, ) -> list[str]: diff --git a/superset/db_engine_specs/starrocks.py b/superset/db_engine_specs/starrocks.py index 9f3cfd764..6f54329d6 100644 --- a/superset/db_engine_specs/starrocks.py +++ b/superset/db_engine_specs/starrocks.py @@ -18,7 +18,7 @@ import logging import re from re import Pattern -from typing import Any, Optional +from typing import Any, Optional, Union from urllib import parse from flask_babel import gettext as __ @@ -28,6 +28,7 @@ from sqlalchemy.sql.type_api import TypeEngine from superset.db_engine_specs.mysql import MySQLEngineSpec from superset.errors import SupersetErrorType +from superset.models.core import Database from superset.utils.core import GenericDataType # Regular expressions to catch custom errors @@ -201,3 +202,50 @@ class StarRocksEngineSpec(MySQLEngineSpec): return None return parse.unquote(database.split(".")[1]) + + @classmethod + def get_url_for_impersonation( + cls, + url: URL, + impersonate_user: bool, + username: Union[str, None] = None, + access_token: Union[str, None] = None, + ) -> URL: + """ + Return a modified URL with the username set. + + :param url: SQLAlchemy URL object + :param impersonate_user: Flag indicating if impersonation is enabled + :param username: Effective username + :param access_token: Personal access token + """ + # Leave URL unchanged. We will impersonate with the pre-query below. + return url + + @classmethod + def get_prequeries( + cls, + database: Database, + catalog: Union[str, None] = None, + schema: Union[str, None] = None, + ) -> list[str]: + """ + Return pre-session queries. + + These are currently used as an alternative to ``adjust_engine_params`` for + databases where the selected schema cannot be specified in the SQLAlchemy URI or + connection arguments. + + For example, in order to specify a default schema in RDS we need to run a query + at the beginning of the session: + + sql> set search_path = my_schema; + + """ + if database.impersonate_user: + username = database.get_effective_user(database.url_object) + + if username: + return [f'EXECUTE AS "{username}" WITH NO REVERT;'] + + return [] diff --git a/superset/models/core.py b/superset/models/core.py index c528d3580..9d432f811 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -560,6 +560,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable # pre-session queries are used to set the selected schema and, in the # future, the selected catalog for prequery in self.db_engine_spec.get_prequeries( + database=self, catalog=catalog, schema=schema, ): diff --git a/tests/unit_tests/db_engine_specs/test_databricks.py b/tests/unit_tests/db_engine_specs/test_databricks.py index 652387471..025784fe2 100644 --- a/tests/unit_tests/db_engine_specs/test_databricks.py +++ b/tests/unit_tests/db_engine_specs/test_databricks.py @@ -247,20 +247,24 @@ def test_convert_dttm( assert_convert_dttm(spec, target_type, expected_result, dttm) -def test_get_prequeries() -> None: +def test_get_prequeries(mocker: MockerFixture) -> None: """ Test the ``get_prequeries`` method. """ from superset.db_engine_specs.databricks import DatabricksNativeEngineSpec - assert DatabricksNativeEngineSpec.get_prequeries() == [] - assert DatabricksNativeEngineSpec.get_prequeries(schema="test") == [ + database = mocker.MagicMock() + + assert DatabricksNativeEngineSpec.get_prequeries(database) == [] + assert DatabricksNativeEngineSpec.get_prequeries(database, schema="test") == [ "USE SCHEMA test", ] - assert DatabricksNativeEngineSpec.get_prequeries(catalog="test") == [ + assert DatabricksNativeEngineSpec.get_prequeries(database, catalog="test") == [ "USE CATALOG test", ] - assert DatabricksNativeEngineSpec.get_prequeries(catalog="foo", schema="bar") == [ + assert DatabricksNativeEngineSpec.get_prequeries( + database, catalog="foo", schema="bar" + ) == [ "USE CATALOG foo", "USE SCHEMA bar", ] diff --git a/tests/unit_tests/db_engine_specs/test_db2.py b/tests/unit_tests/db_engine_specs/test_db2.py index 017fcd7b8..3102c4377 100644 --- a/tests/unit_tests/db_engine_specs/test_db2.py +++ b/tests/unit_tests/db_engine_specs/test_db2.py @@ -66,13 +66,15 @@ def test_get_table_comment_empty(mocker: MockerFixture): ) -def test_get_prequeries() -> None: +def test_get_prequeries(mocker: MockerFixture) -> None: """ Test the ``get_prequeries`` method. """ from superset.db_engine_specs.db2 import Db2EngineSpec - assert Db2EngineSpec.get_prequeries() == [] - assert Db2EngineSpec.get_prequeries(schema="my_schema") == [ + database = mocker.MagicMock() + + assert Db2EngineSpec.get_prequeries(database) == [] + assert Db2EngineSpec.get_prequeries(database, schema="my_schema") == [ 'set current_schema "my_schema"' ] diff --git a/tests/unit_tests/db_engine_specs/test_postgres.py b/tests/unit_tests/db_engine_specs/test_postgres.py index 32b8bb613..da5a4ccf8 100644 --- a/tests/unit_tests/db_engine_specs/test_postgres.py +++ b/tests/unit_tests/db_engine_specs/test_postgres.py @@ -137,14 +137,16 @@ def test_get_schema_from_engine_params() -> None: ) -def test_get_prequeries() -> None: +def test_get_prequeries(mocker: MockerFixture) -> None: """ Test the ``get_prequeries`` method. """ from superset.db_engine_specs.postgres import PostgresEngineSpec - assert PostgresEngineSpec.get_prequeries() == [] - assert PostgresEngineSpec.get_prequeries(schema="test") == [ + database = mocker.MagicMock() + + assert PostgresEngineSpec.get_prequeries(database) == [] + assert PostgresEngineSpec.get_prequeries(database, schema="test") == [ 'set search_path = "test"' ] diff --git a/tests/unit_tests/db_engine_specs/test_starrocks.py b/tests/unit_tests/db_engine_specs/test_starrocks.py index 5d89f5b81..c167755a1 100644 --- a/tests/unit_tests/db_engine_specs/test_starrocks.py +++ b/tests/unit_tests/db_engine_specs/test_starrocks.py @@ -18,6 +18,7 @@ from typing import Any, Optional import pytest +from pytest_mock import MockerFixture from sqlalchemy import JSON, types from sqlalchemy.engine.url import make_url @@ -124,3 +125,47 @@ def test_get_schema_from_engine_params() -> None: ) is None ) + + +def test_impersonation_username(mocker: MockerFixture) -> None: + """ + Test impersonation and make sure that `get_url_for_impersonation` leaves the URL + unchanged and that `get_prequeries` returns the appropriate impersonation query. + """ + from superset.db_engine_specs.starrocks import StarRocksEngineSpec + + database = mocker.MagicMock() + database.impersonate_user = True + database.get_effective_user.return_value = "alice" + + assert StarRocksEngineSpec.get_url_for_impersonation( + url=make_url("starrocks://service_user@localhost:9030/hive.default"), + impersonate_user=True, + username="alice", + access_token=None, + ) == make_url("starrocks://service_user@localhost:9030/hive.default") + + assert StarRocksEngineSpec.get_prequeries(database) == [ + 'EXECUTE AS "alice" WITH NO REVERT;' + ] + + +def test_impersonation_disabled(mocker: MockerFixture) -> None: + """ + Test that impersonation is not applied when the feature is disabled in + `get_url_for_impersonation` and `get_prequeries`. + """ + from superset.db_engine_specs.starrocks import StarRocksEngineSpec + + database = mocker.MagicMock() + database.impersonate_user = False + database.get_effective_user.return_value = "alice" + + assert StarRocksEngineSpec.get_url_for_impersonation( + url=make_url("starrocks://service_user@localhost:9030/hive.default"), + impersonate_user=False, + username="alice", + access_token=None, + ) == make_url("starrocks://service_user@localhost:9030/hive.default") + + assert StarRocksEngineSpec.get_prequeries(database) == []