feat(oauth2): add support for trino (#30081)
This commit is contained in:
parent
64f8140731
commit
305b6df6e3
|
|
@ -433,6 +433,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||||
oauth2_scope = ""
|
oauth2_scope = ""
|
||||||
oauth2_authorization_request_uri: str | None = None # pylint: disable=invalid-name
|
oauth2_authorization_request_uri: str | None = None # pylint: disable=invalid-name
|
||||||
oauth2_token_request_uri: str | None = None
|
oauth2_token_request_uri: str | None = None
|
||||||
|
oauth2_token_request_type = "data"
|
||||||
|
|
||||||
# Driver-specific exception that should be mapped to OAuth2RedirectError
|
# Driver-specific exception that should be mapped to OAuth2RedirectError
|
||||||
oauth2_exception = OAuth2RedirectError
|
oauth2_exception = OAuth2RedirectError
|
||||||
|
|
@ -525,6 +526,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||||
"token_request_uri",
|
"token_request_uri",
|
||||||
cls.oauth2_token_request_uri,
|
cls.oauth2_token_request_uri,
|
||||||
),
|
),
|
||||||
|
"request_content_type": db_engine_spec_config.get(
|
||||||
|
"request_content_type", cls.oauth2_token_request_type
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
return config
|
return config
|
||||||
|
|
@ -562,18 +566,16 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||||
"""
|
"""
|
||||||
timeout = current_app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds()
|
timeout = current_app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds()
|
||||||
uri = config["token_request_uri"]
|
uri = config["token_request_uri"]
|
||||||
response = requests.post(
|
req_body = {
|
||||||
uri,
|
"code": code,
|
||||||
json={
|
"client_id": config["id"],
|
||||||
"code": code,
|
"client_secret": config["secret"],
|
||||||
"client_id": config["id"],
|
"redirect_uri": config["redirect_uri"],
|
||||||
"client_secret": config["secret"],
|
"grant_type": "authorization_code",
|
||||||
"redirect_uri": config["redirect_uri"],
|
}
|
||||||
"grant_type": "authorization_code",
|
if config["request_content_type"] == "data":
|
||||||
},
|
return requests.post(uri, data=req_body, timeout=timeout).json()
|
||||||
timeout=timeout,
|
return requests.post(uri, json=req_body, timeout=timeout).json()
|
||||||
)
|
|
||||||
return response.json()
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_oauth2_fresh_token(
|
def get_oauth2_fresh_token(
|
||||||
|
|
@ -586,17 +588,15 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||||
"""
|
"""
|
||||||
timeout = current_app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds()
|
timeout = current_app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds()
|
||||||
uri = config["token_request_uri"]
|
uri = config["token_request_uri"]
|
||||||
response = requests.post(
|
req_body = {
|
||||||
uri,
|
"client_id": config["id"],
|
||||||
json={
|
"client_secret": config["secret"],
|
||||||
"client_id": config["id"],
|
"refresh_token": refresh_token,
|
||||||
"client_secret": config["secret"],
|
"grant_type": "refresh_token",
|
||||||
"refresh_token": refresh_token,
|
}
|
||||||
"grant_type": "refresh_token",
|
if config["request_content_type"] == "data":
|
||||||
},
|
return requests.post(uri, data=req_body, timeout=timeout).json()
|
||||||
timeout=timeout,
|
return requests.post(uri, json=req_body, timeout=timeout).json()
|
||||||
)
|
|
||||||
return response.json()
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_allows_alias_in_select(
|
def get_allows_alias_in_select(
|
||||||
|
|
|
||||||
|
|
@ -27,11 +27,13 @@ from typing import Any, TYPE_CHECKING
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
from flask import ctx, current_app, Flask, g
|
import requests
|
||||||
|
from flask import copy_current_request_context, ctx, current_app, Flask, g
|
||||||
from sqlalchemy import text
|
from sqlalchemy import text
|
||||||
from sqlalchemy.engine.reflection import Inspector
|
from sqlalchemy.engine.reflection import Inspector
|
||||||
from sqlalchemy.engine.url import URL
|
from sqlalchemy.engine.url import URL
|
||||||
from sqlalchemy.exc import NoSuchTableError
|
from sqlalchemy.exc import NoSuchTableError
|
||||||
|
from trino.exceptions import HttpError
|
||||||
|
|
||||||
from superset import db
|
from superset import db
|
||||||
from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY, USER_AGENT
|
from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY, USER_AGENT
|
||||||
|
|
@ -60,11 +62,28 @@ if TYPE_CHECKING:
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class CustomTrinoAuthErrorMeta(type):
|
||||||
|
def __instancecheck__(cls, instance: object) -> bool:
|
||||||
|
logger.info("is this being called?")
|
||||||
|
return isinstance(
|
||||||
|
instance, HttpError
|
||||||
|
) and "error 401: b'Invalid credentials'" in str(instance)
|
||||||
|
|
||||||
|
|
||||||
|
class TrinoAuthError(HttpError, metaclass=CustomTrinoAuthErrorMeta):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class TrinoEngineSpec(PrestoBaseEngineSpec):
|
class TrinoEngineSpec(PrestoBaseEngineSpec):
|
||||||
engine = "trino"
|
engine = "trino"
|
||||||
engine_name = "Trino"
|
engine_name = "Trino"
|
||||||
allows_alias_to_source_column = False
|
allows_alias_to_source_column = False
|
||||||
|
|
||||||
|
# OAuth 2.0
|
||||||
|
supports_oauth2 = True
|
||||||
|
oauth2_exception = TrinoAuthError
|
||||||
|
oauth2_token_request_type = "data"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_extra_table_metadata(
|
def get_extra_table_metadata(
|
||||||
cls,
|
cls,
|
||||||
|
|
@ -142,6 +161,10 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
|
||||||
# Set principal_username=$effective_username
|
# Set principal_username=$effective_username
|
||||||
if backend_name == "trino" and username is not None:
|
if backend_name == "trino" and username is not None:
|
||||||
connect_args["user"] = username
|
connect_args["user"] = username
|
||||||
|
if access_token is not None:
|
||||||
|
http_session = requests.Session()
|
||||||
|
http_session.headers.update({"Authorization": f"Bearer {access_token}"})
|
||||||
|
connect_args["http_session"] = http_session
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_url_for_impersonation(
|
def get_url_for_impersonation(
|
||||||
|
|
@ -154,6 +177,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
|
||||||
"""
|
"""
|
||||||
Return a modified URL with the username set.
|
Return a modified URL with the username set.
|
||||||
|
|
||||||
|
:param access_token: Personal access token for OAuth2
|
||||||
:param url: SQLAlchemy URL object
|
:param url: SQLAlchemy URL object
|
||||||
:param impersonate_user: Flag indicating if impersonation is enabled
|
:param impersonate_user: Flag indicating if impersonation is enabled
|
||||||
:param username: Effective username
|
:param username: Effective username
|
||||||
|
|
@ -228,6 +252,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
|
||||||
execute_result: dict[str, Any] = {}
|
execute_result: dict[str, Any] = {}
|
||||||
execute_event = threading.Event()
|
execute_event = threading.Event()
|
||||||
|
|
||||||
|
@copy_current_request_context
|
||||||
def _execute(
|
def _execute(
|
||||||
results: dict[str, Any],
|
results: dict[str, Any],
|
||||||
event: threading.Event,
|
event: threading.Event,
|
||||||
|
|
|
||||||
|
|
@ -149,6 +149,10 @@ class OAuth2ClientConfig(TypedDict):
|
||||||
# expired access token.
|
# expired access token.
|
||||||
token_request_uri: str
|
token_request_uri: str
|
||||||
|
|
||||||
|
# Not all identity providers expect json. Keycloak expects a form encoded request,
|
||||||
|
# which in the `requests` package context means using the `data` param, not `json`.
|
||||||
|
request_content_type: str
|
||||||
|
|
||||||
|
|
||||||
class OAuth2TokenResponse(TypedDict, total=False):
|
class OAuth2TokenResponse(TypedDict, total=False):
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,7 @@ from typing import Any, TYPE_CHECKING
|
||||||
import backoff
|
import backoff
|
||||||
import jwt
|
import jwt
|
||||||
from flask import current_app, url_for
|
from flask import current_app, url_for
|
||||||
from marshmallow import EXCLUDE, fields, post_load, Schema
|
from marshmallow import EXCLUDE, fields, post_load, Schema, validate
|
||||||
|
|
||||||
from superset import db
|
from superset import db
|
||||||
from superset.distributed_lock import KeyValueDistributedLock
|
from superset.distributed_lock import KeyValueDistributedLock
|
||||||
|
|
@ -192,3 +192,8 @@ class OAuth2ClientConfigSchema(Schema):
|
||||||
)
|
)
|
||||||
authorization_request_uri = fields.String(required=True)
|
authorization_request_uri = fields.String(required=True)
|
||||||
token_request_uri = fields.String(required=True)
|
token_request_uri = fields.String(required=True)
|
||||||
|
request_content_type = fields.String(
|
||||||
|
required=False,
|
||||||
|
load_default=lambda: "json",
|
||||||
|
validate=validate.OneOf(["json", "data"]),
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -559,6 +559,7 @@ def oauth2_config() -> OAuth2ClientConfig:
|
||||||
"redirect_uri": "http://localhost:8088/api/v1/oauth2/",
|
"redirect_uri": "http://localhost:8088/api/v1/oauth2/",
|
||||||
"authorization_request_uri": "https://accounts.google.com/o/oauth2/v2/auth",
|
"authorization_request_uri": "https://accounts.google.com/o/oauth2/v2/auth",
|
||||||
"token_request_uri": "https://oauth2.googleapis.com/token",
|
"token_request_uri": "https://oauth2.googleapis.com/token",
|
||||||
|
"request_content_type": "json",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -45,7 +45,12 @@ from superset.db_engine_specs.exceptions import (
|
||||||
SupersetDBAPIProgrammingError,
|
SupersetDBAPIProgrammingError,
|
||||||
)
|
)
|
||||||
from superset.sql_parse import Table
|
from superset.sql_parse import Table
|
||||||
from superset.superset_typing import ResultSetColumnType, SQLAColumnType, SQLType
|
from superset.superset_typing import (
|
||||||
|
OAuth2ClientConfig,
|
||||||
|
ResultSetColumnType,
|
||||||
|
SQLAColumnType,
|
||||||
|
SQLType,
|
||||||
|
)
|
||||||
from superset.utils import json
|
from superset.utils import json
|
||||||
from superset.utils.core import GenericDataType
|
from superset.utils.core import GenericDataType
|
||||||
from tests.unit_tests.db_engine_specs.utils import (
|
from tests.unit_tests.db_engine_specs.utils import (
|
||||||
|
|
@ -421,38 +426,9 @@ def test_execute_with_cursor_in_parallel(app, mocker: MockerFixture):
|
||||||
def _mock_execute(*args, **kwargs):
|
def _mock_execute(*args, **kwargs):
|
||||||
mock_cursor.query_id = query_id
|
mock_cursor.query_id = query_id
|
||||||
|
|
||||||
mock_cursor.execute.side_effect = _mock_execute
|
with app.test_request_context("/some/place/"):
|
||||||
with patch.dict(
|
mock_cursor.execute.side_effect = _mock_execute
|
||||||
"superset.config.DISALLOWED_SQL_FUNCTIONS",
|
|
||||||
{},
|
|
||||||
clear=True,
|
|
||||||
):
|
|
||||||
TrinoEngineSpec.execute_with_cursor(
|
|
||||||
cursor=mock_cursor,
|
|
||||||
sql="SELECT 1 FROM foo",
|
|
||||||
query=mock_query,
|
|
||||||
)
|
|
||||||
|
|
||||||
mock_query.set_extra_json_key.assert_called_once_with(
|
|
||||||
key=QUERY_CANCEL_KEY, value=query_id
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_execute_with_cursor_app_context(app, mocker: MockerFixture):
|
|
||||||
"""Test that `execute_with_cursor` still contains the current app context"""
|
|
||||||
from superset.db_engine_specs.trino import TrinoEngineSpec
|
|
||||||
|
|
||||||
mock_cursor = mocker.MagicMock()
|
|
||||||
mock_cursor.query_id = None
|
|
||||||
|
|
||||||
mock_query = mocker.MagicMock()
|
|
||||||
g.some_value = "some_value"
|
|
||||||
|
|
||||||
def _mock_execute(*args, **kwargs):
|
|
||||||
assert has_app_context()
|
|
||||||
assert g.some_value == "some_value"
|
|
||||||
|
|
||||||
with patch.object(TrinoEngineSpec, "execute", side_effect=_mock_execute):
|
|
||||||
with patch.dict(
|
with patch.dict(
|
||||||
"superset.config.DISALLOWED_SQL_FUNCTIONS",
|
"superset.config.DISALLOWED_SQL_FUNCTIONS",
|
||||||
{},
|
{},
|
||||||
|
|
@ -464,6 +440,39 @@ def test_execute_with_cursor_app_context(app, mocker: MockerFixture):
|
||||||
query=mock_query,
|
query=mock_query,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
mock_query.set_extra_json_key.assert_called_once_with(
|
||||||
|
key=QUERY_CANCEL_KEY, value=query_id
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_execute_with_cursor_app_context(app, mocker: MockerFixture):
|
||||||
|
"""Test that `execute_with_cursor` still contains the current app context"""
|
||||||
|
from superset.db_engine_specs.trino import TrinoEngineSpec
|
||||||
|
|
||||||
|
mock_cursor = mocker.MagicMock()
|
||||||
|
mock_cursor.query_id = None
|
||||||
|
|
||||||
|
mock_query = mocker.MagicMock()
|
||||||
|
|
||||||
|
def _mock_execute(*args, **kwargs):
|
||||||
|
assert has_app_context()
|
||||||
|
assert g.some_value == "some_value"
|
||||||
|
|
||||||
|
with app.test_request_context("/some/place/"):
|
||||||
|
g.some_value = "some_value"
|
||||||
|
|
||||||
|
with patch.object(TrinoEngineSpec, "execute", side_effect=_mock_execute):
|
||||||
|
with patch.dict(
|
||||||
|
"superset.config.DISALLOWED_SQL_FUNCTIONS",
|
||||||
|
{},
|
||||||
|
clear=True,
|
||||||
|
):
|
||||||
|
TrinoEngineSpec.execute_with_cursor(
|
||||||
|
cursor=mock_cursor,
|
||||||
|
sql="SELECT 1 FROM foo",
|
||||||
|
query=mock_query,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_get_columns(mocker: MockerFixture):
|
def test_get_columns(mocker: MockerFixture):
|
||||||
"""Test that ROW columns are not expanded without expand_rows"""
|
"""Test that ROW columns are not expanded without expand_rows"""
|
||||||
|
|
@ -784,3 +793,57 @@ def test_where_latest_partition(
|
||||||
)
|
)
|
||||||
== f"""SELECT * FROM table \nWHERE partition_key = {expected_value}"""
|
== f"""SELECT * FROM table \nWHERE partition_key = {expected_value}"""
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def oauth2_config() -> OAuth2ClientConfig:
|
||||||
|
"""
|
||||||
|
Config for Trino OAuth2.
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"id": "trino",
|
||||||
|
"secret": "very-secret",
|
||||||
|
"scope": "",
|
||||||
|
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
|
||||||
|
"authorization_request_uri": "https://trino.auth.server.example/realms/master/protocol/openid-connect/auth",
|
||||||
|
"token_request_uri": "https://trino.auth.server.example/master/protocol/openid-connect/token",
|
||||||
|
"request_content_type": "data",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_oauth2_token(
|
||||||
|
mocker: MockerFixture,
|
||||||
|
oauth2_config: OAuth2ClientConfig,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Test `get_oauth2_token`.
|
||||||
|
"""
|
||||||
|
from superset.db_engine_specs.trino import TrinoEngineSpec
|
||||||
|
|
||||||
|
requests = mocker.patch("superset.db_engine_specs.base.requests")
|
||||||
|
requests.post().json.return_value = {
|
||||||
|
"access_token": "access-token",
|
||||||
|
"expires_in": 3600,
|
||||||
|
"scope": "scope",
|
||||||
|
"token_type": "Bearer",
|
||||||
|
"refresh_token": "refresh-token",
|
||||||
|
}
|
||||||
|
|
||||||
|
assert TrinoEngineSpec.get_oauth2_token(oauth2_config, "code") == {
|
||||||
|
"access_token": "access-token",
|
||||||
|
"expires_in": 3600,
|
||||||
|
"scope": "scope",
|
||||||
|
"token_type": "Bearer",
|
||||||
|
"refresh_token": "refresh-token",
|
||||||
|
}
|
||||||
|
requests.post.assert_called_with(
|
||||||
|
"https://trino.auth.server.example/master/protocol/openid-connect/token",
|
||||||
|
data={
|
||||||
|
"code": "code",
|
||||||
|
"client_id": "trino",
|
||||||
|
"client_secret": "very-secret",
|
||||||
|
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
|
||||||
|
"grant_type": "authorization_code",
|
||||||
|
},
|
||||||
|
timeout=30.0,
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -521,6 +521,7 @@ def test_get_oauth2_config(app_context: None) -> None:
|
||||||
"token_request_uri": "https://abcd1234.snowflakecomputing.com/oauth/token-request",
|
"token_request_uri": "https://abcd1234.snowflakecomputing.com/oauth/token-request",
|
||||||
"scope": "refresh_token session:role:USERADMIN",
|
"scope": "refresh_token session:role:USERADMIN",
|
||||||
"redirect_uri": "http://example.com/api/v1/database/oauth2/",
|
"redirect_uri": "http://example.com/api/v1/database/oauth2/",
|
||||||
|
"request_content_type": "json",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue