feat(oauth2): add support for trino (#30081)
This commit is contained in:
parent
64f8140731
commit
305b6df6e3
|
|
@ -433,6 +433,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
oauth2_scope = ""
|
||||
oauth2_authorization_request_uri: str | None = None # pylint: disable=invalid-name
|
||||
oauth2_token_request_uri: str | None = None
|
||||
oauth2_token_request_type = "data"
|
||||
|
||||
# Driver-specific exception that should be mapped to OAuth2RedirectError
|
||||
oauth2_exception = OAuth2RedirectError
|
||||
|
|
@ -525,6 +526,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
"token_request_uri",
|
||||
cls.oauth2_token_request_uri,
|
||||
),
|
||||
"request_content_type": db_engine_spec_config.get(
|
||||
"request_content_type", cls.oauth2_token_request_type
|
||||
),
|
||||
}
|
||||
|
||||
return config
|
||||
|
|
@ -562,18 +566,16 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
"""
|
||||
timeout = current_app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds()
|
||||
uri = config["token_request_uri"]
|
||||
response = requests.post(
|
||||
uri,
|
||||
json={
|
||||
req_body = {
|
||||
"code": code,
|
||||
"client_id": config["id"],
|
||||
"client_secret": config["secret"],
|
||||
"redirect_uri": config["redirect_uri"],
|
||||
"grant_type": "authorization_code",
|
||||
},
|
||||
timeout=timeout,
|
||||
)
|
||||
return response.json()
|
||||
}
|
||||
if config["request_content_type"] == "data":
|
||||
return requests.post(uri, data=req_body, timeout=timeout).json()
|
||||
return requests.post(uri, json=req_body, timeout=timeout).json()
|
||||
|
||||
@classmethod
|
||||
def get_oauth2_fresh_token(
|
||||
|
|
@ -586,17 +588,15 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
"""
|
||||
timeout = current_app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds()
|
||||
uri = config["token_request_uri"]
|
||||
response = requests.post(
|
||||
uri,
|
||||
json={
|
||||
req_body = {
|
||||
"client_id": config["id"],
|
||||
"client_secret": config["secret"],
|
||||
"refresh_token": refresh_token,
|
||||
"grant_type": "refresh_token",
|
||||
},
|
||||
timeout=timeout,
|
||||
)
|
||||
return response.json()
|
||||
}
|
||||
if config["request_content_type"] == "data":
|
||||
return requests.post(uri, data=req_body, timeout=timeout).json()
|
||||
return requests.post(uri, json=req_body, timeout=timeout).json()
|
||||
|
||||
@classmethod
|
||||
def get_allows_alias_in_select(
|
||||
|
|
|
|||
|
|
@ -27,11 +27,13 @@ from typing import Any, TYPE_CHECKING
|
|||
import numpy as np
|
||||
import pandas as pd
|
||||
import pyarrow as pa
|
||||
from flask import ctx, current_app, Flask, g
|
||||
import requests
|
||||
from flask import copy_current_request_context, ctx, current_app, Flask, g
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
from sqlalchemy.engine.url import URL
|
||||
from sqlalchemy.exc import NoSuchTableError
|
||||
from trino.exceptions import HttpError
|
||||
|
||||
from superset import db
|
||||
from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY, USER_AGENT
|
||||
|
|
@ -60,11 +62,28 @@ if TYPE_CHECKING:
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CustomTrinoAuthErrorMeta(type):
|
||||
def __instancecheck__(cls, instance: object) -> bool:
|
||||
logger.info("is this being called?")
|
||||
return isinstance(
|
||||
instance, HttpError
|
||||
) and "error 401: b'Invalid credentials'" in str(instance)
|
||||
|
||||
|
||||
class TrinoAuthError(HttpError, metaclass=CustomTrinoAuthErrorMeta):
|
||||
pass
|
||||
|
||||
|
||||
class TrinoEngineSpec(PrestoBaseEngineSpec):
|
||||
engine = "trino"
|
||||
engine_name = "Trino"
|
||||
allows_alias_to_source_column = False
|
||||
|
||||
# OAuth 2.0
|
||||
supports_oauth2 = True
|
||||
oauth2_exception = TrinoAuthError
|
||||
oauth2_token_request_type = "data"
|
||||
|
||||
@classmethod
|
||||
def get_extra_table_metadata(
|
||||
cls,
|
||||
|
|
@ -142,6 +161,10 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
|
|||
# Set principal_username=$effective_username
|
||||
if backend_name == "trino" and username is not None:
|
||||
connect_args["user"] = username
|
||||
if access_token is not None:
|
||||
http_session = requests.Session()
|
||||
http_session.headers.update({"Authorization": f"Bearer {access_token}"})
|
||||
connect_args["http_session"] = http_session
|
||||
|
||||
@classmethod
|
||||
def get_url_for_impersonation(
|
||||
|
|
@ -154,6 +177,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
|
|||
"""
|
||||
Return a modified URL with the username set.
|
||||
|
||||
:param access_token: Personal access token for OAuth2
|
||||
:param url: SQLAlchemy URL object
|
||||
:param impersonate_user: Flag indicating if impersonation is enabled
|
||||
:param username: Effective username
|
||||
|
|
@ -228,6 +252,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
|
|||
execute_result: dict[str, Any] = {}
|
||||
execute_event = threading.Event()
|
||||
|
||||
@copy_current_request_context
|
||||
def _execute(
|
||||
results: dict[str, Any],
|
||||
event: threading.Event,
|
||||
|
|
|
|||
|
|
@ -149,6 +149,10 @@ class OAuth2ClientConfig(TypedDict):
|
|||
# expired access token.
|
||||
token_request_uri: str
|
||||
|
||||
# Not all identity providers expect json. Keycloak expects a form encoded request,
|
||||
# which in the `requests` package context means using the `data` param, not `json`.
|
||||
request_content_type: str
|
||||
|
||||
|
||||
class OAuth2TokenResponse(TypedDict, total=False):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ from typing import Any, TYPE_CHECKING
|
|||
import backoff
|
||||
import jwt
|
||||
from flask import current_app, url_for
|
||||
from marshmallow import EXCLUDE, fields, post_load, Schema
|
||||
from marshmallow import EXCLUDE, fields, post_load, Schema, validate
|
||||
|
||||
from superset import db
|
||||
from superset.distributed_lock import KeyValueDistributedLock
|
||||
|
|
@ -192,3 +192,8 @@ class OAuth2ClientConfigSchema(Schema):
|
|||
)
|
||||
authorization_request_uri = fields.String(required=True)
|
||||
token_request_uri = fields.String(required=True)
|
||||
request_content_type = fields.String(
|
||||
required=False,
|
||||
load_default=lambda: "json",
|
||||
validate=validate.OneOf(["json", "data"]),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -559,6 +559,7 @@ def oauth2_config() -> OAuth2ClientConfig:
|
|||
"redirect_uri": "http://localhost:8088/api/v1/oauth2/",
|
||||
"authorization_request_uri": "https://accounts.google.com/o/oauth2/v2/auth",
|
||||
"token_request_uri": "https://oauth2.googleapis.com/token",
|
||||
"request_content_type": "json",
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -45,7 +45,12 @@ from superset.db_engine_specs.exceptions import (
|
|||
SupersetDBAPIProgrammingError,
|
||||
)
|
||||
from superset.sql_parse import Table
|
||||
from superset.superset_typing import ResultSetColumnType, SQLAColumnType, SQLType
|
||||
from superset.superset_typing import (
|
||||
OAuth2ClientConfig,
|
||||
ResultSetColumnType,
|
||||
SQLAColumnType,
|
||||
SQLType,
|
||||
)
|
||||
from superset.utils import json
|
||||
from superset.utils.core import GenericDataType
|
||||
from tests.unit_tests.db_engine_specs.utils import (
|
||||
|
|
@ -421,7 +426,9 @@ def test_execute_with_cursor_in_parallel(app, mocker: MockerFixture):
|
|||
def _mock_execute(*args, **kwargs):
|
||||
mock_cursor.query_id = query_id
|
||||
|
||||
with app.test_request_context("/some/place/"):
|
||||
mock_cursor.execute.side_effect = _mock_execute
|
||||
|
||||
with patch.dict(
|
||||
"superset.config.DISALLOWED_SQL_FUNCTIONS",
|
||||
{},
|
||||
|
|
@ -446,12 +453,14 @@ def test_execute_with_cursor_app_context(app, mocker: MockerFixture):
|
|||
mock_cursor.query_id = None
|
||||
|
||||
mock_query = mocker.MagicMock()
|
||||
g.some_value = "some_value"
|
||||
|
||||
def _mock_execute(*args, **kwargs):
|
||||
assert has_app_context()
|
||||
assert g.some_value == "some_value"
|
||||
|
||||
with app.test_request_context("/some/place/"):
|
||||
g.some_value = "some_value"
|
||||
|
||||
with patch.object(TrinoEngineSpec, "execute", side_effect=_mock_execute):
|
||||
with patch.dict(
|
||||
"superset.config.DISALLOWED_SQL_FUNCTIONS",
|
||||
|
|
@ -784,3 +793,57 @@ def test_where_latest_partition(
|
|||
)
|
||||
== f"""SELECT * FROM table \nWHERE partition_key = {expected_value}"""
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def oauth2_config() -> OAuth2ClientConfig:
|
||||
"""
|
||||
Config for Trino OAuth2.
|
||||
"""
|
||||
return {
|
||||
"id": "trino",
|
||||
"secret": "very-secret",
|
||||
"scope": "",
|
||||
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
|
||||
"authorization_request_uri": "https://trino.auth.server.example/realms/master/protocol/openid-connect/auth",
|
||||
"token_request_uri": "https://trino.auth.server.example/master/protocol/openid-connect/token",
|
||||
"request_content_type": "data",
|
||||
}
|
||||
|
||||
|
||||
def test_get_oauth2_token(
|
||||
mocker: MockerFixture,
|
||||
oauth2_config: OAuth2ClientConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Test `get_oauth2_token`.
|
||||
"""
|
||||
from superset.db_engine_specs.trino import TrinoEngineSpec
|
||||
|
||||
requests = mocker.patch("superset.db_engine_specs.base.requests")
|
||||
requests.post().json.return_value = {
|
||||
"access_token": "access-token",
|
||||
"expires_in": 3600,
|
||||
"scope": "scope",
|
||||
"token_type": "Bearer",
|
||||
"refresh_token": "refresh-token",
|
||||
}
|
||||
|
||||
assert TrinoEngineSpec.get_oauth2_token(oauth2_config, "code") == {
|
||||
"access_token": "access-token",
|
||||
"expires_in": 3600,
|
||||
"scope": "scope",
|
||||
"token_type": "Bearer",
|
||||
"refresh_token": "refresh-token",
|
||||
}
|
||||
requests.post.assert_called_with(
|
||||
"https://trino.auth.server.example/master/protocol/openid-connect/token",
|
||||
data={
|
||||
"code": "code",
|
||||
"client_id": "trino",
|
||||
"client_secret": "very-secret",
|
||||
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
|
||||
"grant_type": "authorization_code",
|
||||
},
|
||||
timeout=30.0,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -521,6 +521,7 @@ def test_get_oauth2_config(app_context: None) -> None:
|
|||
"token_request_uri": "https://abcd1234.snowflakecomputing.com/oauth/token-request",
|
||||
"scope": "refresh_token session:role:USERADMIN",
|
||||
"redirect_uri": "http://example.com/api/v1/database/oauth2/",
|
||||
"request_content_type": "json",
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue