feat(oauth2): add support for trino (#30081)

This commit is contained in:
João Ferrão 2024-11-04 17:54:47 +01:00 committed by GitHub
parent 64f8140731
commit 305b6df6e3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 156 additions and 57 deletions

View File

@ -433,6 +433,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
oauth2_scope = ""
oauth2_authorization_request_uri: str | None = None # pylint: disable=invalid-name
oauth2_token_request_uri: str | None = None
oauth2_token_request_type = "data"
# Driver-specific exception that should be mapped to OAuth2RedirectError
oauth2_exception = OAuth2RedirectError
@ -525,6 +526,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
"token_request_uri",
cls.oauth2_token_request_uri,
),
"request_content_type": db_engine_spec_config.get(
"request_content_type", cls.oauth2_token_request_type
),
}
return config
@ -562,18 +566,16 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
"""
timeout = current_app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds()
uri = config["token_request_uri"]
response = requests.post(
uri,
json={
req_body = {
"code": code,
"client_id": config["id"],
"client_secret": config["secret"],
"redirect_uri": config["redirect_uri"],
"grant_type": "authorization_code",
},
timeout=timeout,
)
return response.json()
}
if config["request_content_type"] == "data":
return requests.post(uri, data=req_body, timeout=timeout).json()
return requests.post(uri, json=req_body, timeout=timeout).json()
@classmethod
def get_oauth2_fresh_token(
@ -586,17 +588,15 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
"""
timeout = current_app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds()
uri = config["token_request_uri"]
response = requests.post(
uri,
json={
req_body = {
"client_id": config["id"],
"client_secret": config["secret"],
"refresh_token": refresh_token,
"grant_type": "refresh_token",
},
timeout=timeout,
)
return response.json()
}
if config["request_content_type"] == "data":
return requests.post(uri, data=req_body, timeout=timeout).json()
return requests.post(uri, json=req_body, timeout=timeout).json()
@classmethod
def get_allows_alias_in_select(

View File

@ -27,11 +27,13 @@ from typing import Any, TYPE_CHECKING
import numpy as np
import pandas as pd
import pyarrow as pa
from flask import ctx, current_app, Flask, g
import requests
from flask import copy_current_request_context, ctx, current_app, Flask, g
from sqlalchemy import text
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import URL
from sqlalchemy.exc import NoSuchTableError
from trino.exceptions import HttpError
from superset import db
from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY, USER_AGENT
@ -60,11 +62,28 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
class CustomTrinoAuthErrorMeta(type):
    """
    Metaclass that customizes ``isinstance`` checks for Trino auth errors.

    The check succeeds for any ``HttpError`` whose string representation
    contains the "401 / Invalid credentials" marker below, regardless of the
    concrete class of the instance.
    """

    def __instancecheck__(cls, instance: object) -> bool:
        """
        Return whether ``instance`` is an ``HttpError`` for invalid credentials.

        :param instance: Object being tested with ``isinstance``
        :returns: ``True`` only when ``instance`` is an ``HttpError`` and its
            message contains the invalid-credentials marker
        """
        # NOTE(review): this hook is consulted by explicit ``isinstance()``
        # calls; ``except`` clause matching is not expected to go through it --
        # confirm callers rely on ``isinstance``.
        return isinstance(
            instance, HttpError
        ) and "error 401: b'Invalid credentials'" in str(instance)
class TrinoAuthError(HttpError, metaclass=CustomTrinoAuthErrorMeta):
    """
    Marker exception type for Trino authentication failures.

    Adds no behavior of its own; ``isinstance`` checks against it are decided
    by ``CustomTrinoAuthErrorMeta``.
    """
class TrinoEngineSpec(PrestoBaseEngineSpec):
engine = "trino"
engine_name = "Trino"
allows_alias_to_source_column = False
# OAuth 2.0
supports_oauth2 = True
oauth2_exception = TrinoAuthError
oauth2_token_request_type = "data"
@classmethod
def get_extra_table_metadata(
cls,
@ -142,6 +161,10 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
# Set principal_username=$effective_username
if backend_name == "trino" and username is not None:
connect_args["user"] = username
if access_token is not None:
http_session = requests.Session()
http_session.headers.update({"Authorization": f"Bearer {access_token}"})
connect_args["http_session"] = http_session
@classmethod
def get_url_for_impersonation(
@ -154,6 +177,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
"""
Return a modified URL with the username set.
:param access_token: Personal access token for OAuth2
:param url: SQLAlchemy URL object
:param impersonate_user: Flag indicating if impersonation is enabled
:param username: Effective username
@ -228,6 +252,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
execute_result: dict[str, Any] = {}
execute_event = threading.Event()
@copy_current_request_context
def _execute(
results: dict[str, Any],
event: threading.Event,

View File

@ -149,6 +149,10 @@ class OAuth2ClientConfig(TypedDict):
# expired access token.
token_request_uri: str
# Not all identity providers expect a JSON body. Keycloak, for example, expects
# a form-encoded request, which with the `requests` package means passing the
# payload via the `data` parameter instead of `json`.
request_content_type: str
class OAuth2TokenResponse(TypedDict, total=False):
"""

View File

@ -23,7 +23,7 @@ from typing import Any, TYPE_CHECKING
import backoff
import jwt
from flask import current_app, url_for
from marshmallow import EXCLUDE, fields, post_load, Schema
from marshmallow import EXCLUDE, fields, post_load, Schema, validate
from superset import db
from superset.distributed_lock import KeyValueDistributedLock
@ -192,3 +192,8 @@ class OAuth2ClientConfigSchema(Schema):
)
authorization_request_uri = fields.String(required=True)
token_request_uri = fields.String(required=True)
request_content_type = fields.String(
required=False,
load_default=lambda: "json",
validate=validate.OneOf(["json", "data"]),
)

View File

@ -559,6 +559,7 @@ def oauth2_config() -> OAuth2ClientConfig:
"redirect_uri": "http://localhost:8088/api/v1/oauth2/",
"authorization_request_uri": "https://accounts.google.com/o/oauth2/v2/auth",
"token_request_uri": "https://oauth2.googleapis.com/token",
"request_content_type": "json",
}

View File

@ -45,7 +45,12 @@ from superset.db_engine_specs.exceptions import (
SupersetDBAPIProgrammingError,
)
from superset.sql_parse import Table
from superset.superset_typing import ResultSetColumnType, SQLAColumnType, SQLType
from superset.superset_typing import (
OAuth2ClientConfig,
ResultSetColumnType,
SQLAColumnType,
SQLType,
)
from superset.utils import json
from superset.utils.core import GenericDataType
from tests.unit_tests.db_engine_specs.utils import (
@ -421,7 +426,9 @@ def test_execute_with_cursor_in_parallel(app, mocker: MockerFixture):
def _mock_execute(*args, **kwargs):
mock_cursor.query_id = query_id
with app.test_request_context("/some/place/"):
mock_cursor.execute.side_effect = _mock_execute
with patch.dict(
"superset.config.DISALLOWED_SQL_FUNCTIONS",
{},
@ -446,12 +453,14 @@ def test_execute_with_cursor_app_context(app, mocker: MockerFixture):
mock_cursor.query_id = None
mock_query = mocker.MagicMock()
g.some_value = "some_value"
def _mock_execute(*args, **kwargs):
assert has_app_context()
assert g.some_value == "some_value"
with app.test_request_context("/some/place/"):
g.some_value = "some_value"
with patch.object(TrinoEngineSpec, "execute", side_effect=_mock_execute):
with patch.dict(
"superset.config.DISALLOWED_SQL_FUNCTIONS",
@ -784,3 +793,57 @@ def test_where_latest_partition(
)
== f"""SELECT * FROM table \nWHERE partition_key = {expected_value}"""
)
@pytest.fixture
def oauth2_config() -> OAuth2ClientConfig:
    """
    Provide an OAuth2 client configuration for the Trino tests.

    The token request is configured as form-encoded (``"data"``).
    """
    trino_oauth2: OAuth2ClientConfig = {
        # Client credentials
        "id": "trino",
        "secret": "very-secret",
        "scope": "",
        # Endpoints
        "redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
        "authorization_request_uri": "https://trino.auth.server.example/realms/master/protocol/openid-connect/auth",
        "token_request_uri": "https://trino.auth.server.example/master/protocol/openid-connect/token",
        # Send the token request body via the `data` param of `requests.post`
        "request_content_type": "data",
    }
    return trino_oauth2
def test_get_oauth2_token(
    mocker: MockerFixture,
    oauth2_config: OAuth2ClientConfig,
) -> None:
    """
    Test `get_oauth2_token`.

    Checks that the token response is returned as-is and that the request
    body is passed to `requests.post` through the `data` keyword.
    """
    from superset.db_engine_specs.trino import TrinoEngineSpec

    token_response = {
        "access_token": "access-token",
        "expires_in": 3600,
        "scope": "scope",
        "token_type": "Bearer",
        "refresh_token": "refresh-token",
    }
    expected_request_body = {
        "code": "code",
        "client_id": "trino",
        "client_secret": "very-secret",
        "redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
        "grant_type": "authorization_code",
    }

    mocked_requests = mocker.patch("superset.db_engine_specs.base.requests")
    mocked_requests.post().json.return_value = token_response

    token = TrinoEngineSpec.get_oauth2_token(oauth2_config, "code")

    # Compare against a copy so the check does not rely on object identity.
    assert token == dict(token_response)
    mocked_requests.post.assert_called_with(
        "https://trino.auth.server.example/master/protocol/openid-connect/token",
        data=expected_request_body,
        timeout=30.0,
    )

View File

@ -521,6 +521,7 @@ def test_get_oauth2_config(app_context: None) -> None:
"token_request_uri": "https://abcd1234.snowflakecomputing.com/oauth/token-request",
"scope": "refresh_token session:role:USERADMIN",
"redirect_uri": "http://example.com/api/v1/database/oauth2/",
"request_content_type": "json",
}