feat(db_engine): Implement user impersonation support for StarRocks (#28110)
This commit is contained in:
parent
d3f5c795ff
commit
6294e339e2
|
|
@ -95,7 +95,7 @@ The table below (generated via `python superset/db_engine_specs/lib.py`) summari
|
|||
| Masks/unmasks encrypted_extra | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | True | True | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False |
|
||||
| Has column type mappings | False | False | False | False | False | True | False | False | False | False | True | False | True | True | True | True | True | True | False | False | True | False | False | False | False | False | False | False | False | False | False | False | False | False | True | True | True | False | False | False | True | True | False | False | False | False | False | True | False | True | False | True |
|
||||
| Returns a list of function names | False | False | False | False | False | True | False | False | False | False | True | False | False | False | False | True | True | False | False | False | True | False | False | False | False | False | False | False | False | False | True | False | False | False | False | False | False | False | False | False | True | False | False | False | True | True | False | False | False | True | False | True |
|
||||
| Supports user impersonation | False | False | False | True | False | True | False | False | False | False | True | False | False | False | False | False | False | False | False | False | True | False | False | False | False | False | False | False | False | False | True | False | False | False | False | False | False | False | False | False | True | False | False | False | False | False | False | False | False | True | False | False |
|
||||
| Supports user impersonation | False | False | False | True | False | True | False | False | False | False | True | False | False | False | False | False | False | False | False | False | True | False | False | False | False | False | False | False | False | False | True | False | False | False | False | False | False | False | False | False | True | False | False | False | False | False | False | True | False | True | False | False |
|
||||
| Support file upload | True | True | True | True | True | True | True | True | True | True | True | True | True | True | True | False | False | True | True | True | True | True | True | True | True | True | True | True | True | True | False | True | True | True | True | True | True | True | True | True | True | True | True | True | True | True | True | True | True | True | True | True |
|
||||
| Returns extra table metadata | False | False | False | False | False | True | False | False | False | False | True | False | False | False | False | False | False | False | False | False | True | False | False | False | False | False | False | False | False | True | True | False | False | False | False | False | False | False | False | False | True | False | False | False | False | False | False | False | False | True | False | False |
|
||||
| Maps driver exceptions to Superset exceptions | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False |
|
||||
|
|
|
|||
|
|
@ -1408,6 +1408,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
@classmethod
|
||||
def get_prequeries(
|
||||
cls,
|
||||
database: Database, # pylint: disable=unused-argument
|
||||
catalog: str | None = None, # pylint: disable=unused-argument
|
||||
schema: str | None = None, # pylint: disable=unused-argument
|
||||
) -> list[str]:
|
||||
|
|
|
|||
|
|
@ -458,6 +458,7 @@ class DatabricksNativeEngineSpec(DatabricksDynamicBaseEngineSpec):
|
|||
@classmethod
|
||||
def get_prequeries(
|
||||
cls,
|
||||
database: Database,
|
||||
catalog: str | None = None,
|
||||
schema: str | None = None,
|
||||
) -> list[str]:
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ from sqlalchemy.engine.reflection import Inspector
|
|||
|
||||
from superset.constants import TimeGrain
|
||||
from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod
|
||||
from superset.models.core import Database
|
||||
from superset.sql_parse import Table
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -93,6 +94,7 @@ class Db2EngineSpec(BaseEngineSpec):
|
|||
@classmethod
|
||||
def get_prequeries(
|
||||
cls,
|
||||
database: Database,
|
||||
catalog: Union[str, None] = None,
|
||||
schema: Union[str, None] = None,
|
||||
) -> list[str]:
|
||||
|
|
|
|||
|
|
@ -322,6 +322,7 @@ class PostgresEngineSpec(BasicParametersMixin, PostgresBaseEngineSpec):
|
|||
@classmethod
|
||||
def get_prequeries(
|
||||
cls,
|
||||
database: Database,
|
||||
catalog: str | None = None,
|
||||
schema: str | None = None,
|
||||
) -> list[str]:
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@
|
|||
import logging
|
||||
import re
|
||||
from re import Pattern
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Optional, Union
|
||||
from urllib import parse
|
||||
|
||||
from flask_babel import gettext as __
|
||||
|
|
@ -28,6 +28,7 @@ from sqlalchemy.sql.type_api import TypeEngine
|
|||
|
||||
from superset.db_engine_specs.mysql import MySQLEngineSpec
|
||||
from superset.errors import SupersetErrorType
|
||||
from superset.models.core import Database
|
||||
from superset.utils.core import GenericDataType
|
||||
|
||||
# Regular expressions to catch custom errors
|
||||
|
|
@ -201,3 +202,50 @@ class StarRocksEngineSpec(MySQLEngineSpec):
|
|||
return None
|
||||
|
||||
return parse.unquote(database.split(".")[1])
|
||||
|
||||
@classmethod
|
||||
def get_url_for_impersonation(
|
||||
cls,
|
||||
url: URL,
|
||||
impersonate_user: bool,
|
||||
username: Union[str, None] = None,
|
||||
access_token: Union[str, None] = None,
|
||||
) -> URL:
|
||||
"""
|
||||
Return a modified URL with the username set.
|
||||
|
||||
:param url: SQLAlchemy URL object
|
||||
:param impersonate_user: Flag indicating if impersonation is enabled
|
||||
:param username: Effective username
|
||||
:param access_token: Personal access token
|
||||
"""
|
||||
# Leave URL unchanged. We will impersonate with the pre-query below.
|
||||
return url
|
||||
|
||||
@classmethod
|
||||
def get_prequeries(
|
||||
cls,
|
||||
database: Database,
|
||||
catalog: Union[str, None] = None,
|
||||
schema: Union[str, None] = None,
|
||||
) -> list[str]:
|
||||
"""
|
||||
Return pre-session queries.
|
||||
|
||||
These are currently used as an alternative to ``adjust_engine_params`` for
|
||||
databases where the selected schema cannot be specified in the SQLAlchemy URI or
|
||||
connection arguments.
|
||||
|
||||
For example, in order to specify a default schema in RDS we need to run a query
|
||||
at the beginning of the session:
|
||||
|
||||
sql> set search_path = my_schema;
|
||||
|
||||
"""
|
||||
if database.impersonate_user:
|
||||
username = database.get_effective_user(database.url_object)
|
||||
|
||||
if username:
|
||||
return [f'EXECUTE AS "{username}" WITH NO REVERT;']
|
||||
|
||||
return []
|
||||
|
|
|
|||
|
|
@ -560,6 +560,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
|
|||
# pre-session queries are used to set the selected schema and, in the
|
||||
# future, the selected catalog
|
||||
for prequery in self.db_engine_spec.get_prequeries(
|
||||
database=self,
|
||||
catalog=catalog,
|
||||
schema=schema,
|
||||
):
|
||||
|
|
|
|||
|
|
@ -247,20 +247,24 @@ def test_convert_dttm(
|
|||
assert_convert_dttm(spec, target_type, expected_result, dttm)
|
||||
|
||||
|
||||
def test_get_prequeries() -> None:
|
||||
def test_get_prequeries(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test the ``get_prequeries`` method.
|
||||
"""
|
||||
from superset.db_engine_specs.databricks import DatabricksNativeEngineSpec
|
||||
|
||||
assert DatabricksNativeEngineSpec.get_prequeries() == []
|
||||
assert DatabricksNativeEngineSpec.get_prequeries(schema="test") == [
|
||||
database = mocker.MagicMock()
|
||||
|
||||
assert DatabricksNativeEngineSpec.get_prequeries(database) == []
|
||||
assert DatabricksNativeEngineSpec.get_prequeries(database, schema="test") == [
|
||||
"USE SCHEMA test",
|
||||
]
|
||||
assert DatabricksNativeEngineSpec.get_prequeries(catalog="test") == [
|
||||
assert DatabricksNativeEngineSpec.get_prequeries(database, catalog="test") == [
|
||||
"USE CATALOG test",
|
||||
]
|
||||
assert DatabricksNativeEngineSpec.get_prequeries(catalog="foo", schema="bar") == [
|
||||
assert DatabricksNativeEngineSpec.get_prequeries(
|
||||
database, catalog="foo", schema="bar"
|
||||
) == [
|
||||
"USE CATALOG foo",
|
||||
"USE SCHEMA bar",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -66,13 +66,15 @@ def test_get_table_comment_empty(mocker: MockerFixture):
|
|||
)
|
||||
|
||||
|
||||
def test_get_prequeries() -> None:
|
||||
def test_get_prequeries(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test the ``get_prequeries`` method.
|
||||
"""
|
||||
from superset.db_engine_specs.db2 import Db2EngineSpec
|
||||
|
||||
assert Db2EngineSpec.get_prequeries() == []
|
||||
assert Db2EngineSpec.get_prequeries(schema="my_schema") == [
|
||||
database = mocker.MagicMock()
|
||||
|
||||
assert Db2EngineSpec.get_prequeries(database) == []
|
||||
assert Db2EngineSpec.get_prequeries(database, schema="my_schema") == [
|
||||
'set current_schema "my_schema"'
|
||||
]
|
||||
|
|
|
|||
|
|
@ -137,14 +137,16 @@ def test_get_schema_from_engine_params() -> None:
|
|||
)
|
||||
|
||||
|
||||
def test_get_prequeries() -> None:
|
||||
def test_get_prequeries(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test the ``get_prequeries`` method.
|
||||
"""
|
||||
from superset.db_engine_specs.postgres import PostgresEngineSpec
|
||||
|
||||
assert PostgresEngineSpec.get_prequeries() == []
|
||||
assert PostgresEngineSpec.get_prequeries(schema="test") == [
|
||||
database = mocker.MagicMock()
|
||||
|
||||
assert PostgresEngineSpec.get_prequeries(database) == []
|
||||
assert PostgresEngineSpec.get_prequeries(database, schema="test") == [
|
||||
'set search_path = "test"'
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@
|
|||
from typing import Any, Optional
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
from sqlalchemy import JSON, types
|
||||
from sqlalchemy.engine.url import make_url
|
||||
|
||||
|
|
@ -124,3 +125,47 @@ def test_get_schema_from_engine_params() -> None:
|
|||
)
|
||||
is None
|
||||
)
|
||||
|
||||
|
||||
def test_impersonation_username(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test impersonation and make sure that `get_url_for_impersonation` leaves the URL
|
||||
unchanged and that `get_prequeries` returns the appropriate impersonation query.
|
||||
"""
|
||||
from superset.db_engine_specs.starrocks import StarRocksEngineSpec
|
||||
|
||||
database = mocker.MagicMock()
|
||||
database.impersonate_user = True
|
||||
database.get_effective_user.return_value = "alice"
|
||||
|
||||
assert StarRocksEngineSpec.get_url_for_impersonation(
|
||||
url=make_url("starrocks://service_user@localhost:9030/hive.default"),
|
||||
impersonate_user=True,
|
||||
username="alice",
|
||||
access_token=None,
|
||||
) == make_url("starrocks://service_user@localhost:9030/hive.default")
|
||||
|
||||
assert StarRocksEngineSpec.get_prequeries(database) == [
|
||||
'EXECUTE AS "alice" WITH NO REVERT;'
|
||||
]
|
||||
|
||||
|
||||
def test_impersonation_disabled(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test that impersonation is not applied when the feature is disabled in
|
||||
`get_url_for_impersonation` and `get_prequeries`.
|
||||
"""
|
||||
from superset.db_engine_specs.starrocks import StarRocksEngineSpec
|
||||
|
||||
database = mocker.MagicMock()
|
||||
database.impersonate_user = False
|
||||
database.get_effective_user.return_value = "alice"
|
||||
|
||||
assert StarRocksEngineSpec.get_url_for_impersonation(
|
||||
url=make_url("starrocks://service_user@localhost:9030/hive.default"),
|
||||
impersonate_user=False,
|
||||
username="alice",
|
||||
access_token=None,
|
||||
) == make_url("starrocks://service_user@localhost:9030/hive.default")
|
||||
|
||||
assert StarRocksEngineSpec.get_prequeries(database) == []
|
||||
|
|
|
|||
Loading…
Reference in New Issue