feat(oauth2): add support for trino (#30081)

This commit is contained in:
João Ferrão 2024-11-04 17:54:47 +01:00 committed by GitHub
parent 64f8140731
commit 305b6df6e3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 156 additions and 57 deletions

View File

@ -433,6 +433,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
oauth2_scope = "" oauth2_scope = ""
oauth2_authorization_request_uri: str | None = None # pylint: disable=invalid-name oauth2_authorization_request_uri: str | None = None # pylint: disable=invalid-name
oauth2_token_request_uri: str | None = None oauth2_token_request_uri: str | None = None
oauth2_token_request_type = "data"
# Driver-specific exception that should be mapped to OAuth2RedirectError # Driver-specific exception that should be mapped to OAuth2RedirectError
oauth2_exception = OAuth2RedirectError oauth2_exception = OAuth2RedirectError
@ -525,6 +526,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
"token_request_uri", "token_request_uri",
cls.oauth2_token_request_uri, cls.oauth2_token_request_uri,
), ),
"request_content_type": db_engine_spec_config.get(
"request_content_type", cls.oauth2_token_request_type
),
} }
return config return config
@ -562,18 +566,16 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
""" """
timeout = current_app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds() timeout = current_app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds()
uri = config["token_request_uri"] uri = config["token_request_uri"]
response = requests.post( req_body = {
uri,
json={
"code": code, "code": code,
"client_id": config["id"], "client_id": config["id"],
"client_secret": config["secret"], "client_secret": config["secret"],
"redirect_uri": config["redirect_uri"], "redirect_uri": config["redirect_uri"],
"grant_type": "authorization_code", "grant_type": "authorization_code",
}, }
timeout=timeout, if config["request_content_type"] == "data":
) return requests.post(uri, data=req_body, timeout=timeout).json()
return response.json() return requests.post(uri, json=req_body, timeout=timeout).json()
@classmethod @classmethod
def get_oauth2_fresh_token( def get_oauth2_fresh_token(
@ -586,17 +588,15 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
""" """
timeout = current_app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds() timeout = current_app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds()
uri = config["token_request_uri"] uri = config["token_request_uri"]
response = requests.post( req_body = {
uri,
json={
"client_id": config["id"], "client_id": config["id"],
"client_secret": config["secret"], "client_secret": config["secret"],
"refresh_token": refresh_token, "refresh_token": refresh_token,
"grant_type": "refresh_token", "grant_type": "refresh_token",
}, }
timeout=timeout, if config["request_content_type"] == "data":
) return requests.post(uri, data=req_body, timeout=timeout).json()
return response.json() return requests.post(uri, json=req_body, timeout=timeout).json()
@classmethod @classmethod
def get_allows_alias_in_select( def get_allows_alias_in_select(

View File

@ -27,11 +27,13 @@ from typing import Any, TYPE_CHECKING
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import pyarrow as pa import pyarrow as pa
from flask import ctx, current_app, Flask, g import requests
from flask import copy_current_request_context, ctx, current_app, Flask, g
from sqlalchemy import text from sqlalchemy import text
from sqlalchemy.engine.reflection import Inspector from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import URL from sqlalchemy.engine.url import URL
from sqlalchemy.exc import NoSuchTableError from sqlalchemy.exc import NoSuchTableError
from trino.exceptions import HttpError
from superset import db from superset import db
from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY, USER_AGENT from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY, USER_AGENT
@ -60,11 +62,28 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class CustomTrinoAuthErrorMeta(type):
def __instancecheck__(cls, instance: object) -> bool:
logger.info("is this being called?")
return isinstance(
instance, HttpError
) and "error 401: b'Invalid credentials'" in str(instance)
class TrinoAuthError(HttpError, metaclass=CustomTrinoAuthErrorMeta):
pass
class TrinoEngineSpec(PrestoBaseEngineSpec): class TrinoEngineSpec(PrestoBaseEngineSpec):
engine = "trino" engine = "trino"
engine_name = "Trino" engine_name = "Trino"
allows_alias_to_source_column = False allows_alias_to_source_column = False
# OAuth 2.0
supports_oauth2 = True
oauth2_exception = TrinoAuthError
oauth2_token_request_type = "data"
@classmethod @classmethod
def get_extra_table_metadata( def get_extra_table_metadata(
cls, cls,
@ -142,6 +161,10 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
# Set principal_username=$effective_username # Set principal_username=$effective_username
if backend_name == "trino" and username is not None: if backend_name == "trino" and username is not None:
connect_args["user"] = username connect_args["user"] = username
if access_token is not None:
http_session = requests.Session()
http_session.headers.update({"Authorization": f"Bearer {access_token}"})
connect_args["http_session"] = http_session
@classmethod @classmethod
def get_url_for_impersonation( def get_url_for_impersonation(
@ -154,6 +177,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
""" """
Return a modified URL with the username set. Return a modified URL with the username set.
:param access_token: Personal access token for OAuth2
:param url: SQLAlchemy URL object :param url: SQLAlchemy URL object
:param impersonate_user: Flag indicating if impersonation is enabled :param impersonate_user: Flag indicating if impersonation is enabled
:param username: Effective username :param username: Effective username
@ -228,6 +252,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
execute_result: dict[str, Any] = {} execute_result: dict[str, Any] = {}
execute_event = threading.Event() execute_event = threading.Event()
@copy_current_request_context
def _execute( def _execute(
results: dict[str, Any], results: dict[str, Any],
event: threading.Event, event: threading.Event,

View File

@ -149,6 +149,10 @@ class OAuth2ClientConfig(TypedDict):
# expired access token. # expired access token.
token_request_uri: str token_request_uri: str
# Not all identity providers expect json. Keycloak expects a form encoded request,
# which in the `requests` package context means using the `data` param, not `json`.
request_content_type: str
class OAuth2TokenResponse(TypedDict, total=False): class OAuth2TokenResponse(TypedDict, total=False):
""" """

View File

@ -23,7 +23,7 @@ from typing import Any, TYPE_CHECKING
import backoff import backoff
import jwt import jwt
from flask import current_app, url_for from flask import current_app, url_for
from marshmallow import EXCLUDE, fields, post_load, Schema from marshmallow import EXCLUDE, fields, post_load, Schema, validate
from superset import db from superset import db
from superset.distributed_lock import KeyValueDistributedLock from superset.distributed_lock import KeyValueDistributedLock
@ -192,3 +192,8 @@ class OAuth2ClientConfigSchema(Schema):
) )
authorization_request_uri = fields.String(required=True) authorization_request_uri = fields.String(required=True)
token_request_uri = fields.String(required=True) token_request_uri = fields.String(required=True)
request_content_type = fields.String(
required=False,
load_default=lambda: "json",
validate=validate.OneOf(["json", "data"]),
)

View File

@ -559,6 +559,7 @@ def oauth2_config() -> OAuth2ClientConfig:
"redirect_uri": "http://localhost:8088/api/v1/oauth2/", "redirect_uri": "http://localhost:8088/api/v1/oauth2/",
"authorization_request_uri": "https://accounts.google.com/o/oauth2/v2/auth", "authorization_request_uri": "https://accounts.google.com/o/oauth2/v2/auth",
"token_request_uri": "https://oauth2.googleapis.com/token", "token_request_uri": "https://oauth2.googleapis.com/token",
"request_content_type": "json",
} }

View File

@ -45,7 +45,12 @@ from superset.db_engine_specs.exceptions import (
SupersetDBAPIProgrammingError, SupersetDBAPIProgrammingError,
) )
from superset.sql_parse import Table from superset.sql_parse import Table
from superset.superset_typing import ResultSetColumnType, SQLAColumnType, SQLType from superset.superset_typing import (
OAuth2ClientConfig,
ResultSetColumnType,
SQLAColumnType,
SQLType,
)
from superset.utils import json from superset.utils import json
from superset.utils.core import GenericDataType from superset.utils.core import GenericDataType
from tests.unit_tests.db_engine_specs.utils import ( from tests.unit_tests.db_engine_specs.utils import (
@ -421,7 +426,9 @@ def test_execute_with_cursor_in_parallel(app, mocker: MockerFixture):
def _mock_execute(*args, **kwargs): def _mock_execute(*args, **kwargs):
mock_cursor.query_id = query_id mock_cursor.query_id = query_id
with app.test_request_context("/some/place/"):
mock_cursor.execute.side_effect = _mock_execute mock_cursor.execute.side_effect = _mock_execute
with patch.dict( with patch.dict(
"superset.config.DISALLOWED_SQL_FUNCTIONS", "superset.config.DISALLOWED_SQL_FUNCTIONS",
{}, {},
@ -446,12 +453,14 @@ def test_execute_with_cursor_app_context(app, mocker: MockerFixture):
mock_cursor.query_id = None mock_cursor.query_id = None
mock_query = mocker.MagicMock() mock_query = mocker.MagicMock()
g.some_value = "some_value"
def _mock_execute(*args, **kwargs): def _mock_execute(*args, **kwargs):
assert has_app_context() assert has_app_context()
assert g.some_value == "some_value" assert g.some_value == "some_value"
with app.test_request_context("/some/place/"):
g.some_value = "some_value"
with patch.object(TrinoEngineSpec, "execute", side_effect=_mock_execute): with patch.object(TrinoEngineSpec, "execute", side_effect=_mock_execute):
with patch.dict( with patch.dict(
"superset.config.DISALLOWED_SQL_FUNCTIONS", "superset.config.DISALLOWED_SQL_FUNCTIONS",
@ -784,3 +793,57 @@ def test_where_latest_partition(
) )
== f"""SELECT * FROM table \nWHERE partition_key = {expected_value}""" == f"""SELECT * FROM table \nWHERE partition_key = {expected_value}"""
) )
@pytest.fixture
def oauth2_config() -> OAuth2ClientConfig:
"""
Config for Trino OAuth2.
"""
return {
"id": "trino",
"secret": "very-secret",
"scope": "",
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
"authorization_request_uri": "https://trino.auth.server.example/realms/master/protocol/openid-connect/auth",
"token_request_uri": "https://trino.auth.server.example/master/protocol/openid-connect/token",
"request_content_type": "data",
}
def test_get_oauth2_token(
mocker: MockerFixture,
oauth2_config: OAuth2ClientConfig,
) -> None:
"""
Test `get_oauth2_token`.
"""
from superset.db_engine_specs.trino import TrinoEngineSpec
requests = mocker.patch("superset.db_engine_specs.base.requests")
requests.post().json.return_value = {
"access_token": "access-token",
"expires_in": 3600,
"scope": "scope",
"token_type": "Bearer",
"refresh_token": "refresh-token",
}
assert TrinoEngineSpec.get_oauth2_token(oauth2_config, "code") == {
"access_token": "access-token",
"expires_in": 3600,
"scope": "scope",
"token_type": "Bearer",
"refresh_token": "refresh-token",
}
requests.post.assert_called_with(
"https://trino.auth.server.example/master/protocol/openid-connect/token",
data={
"code": "code",
"client_id": "trino",
"client_secret": "very-secret",
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
"grant_type": "authorization_code",
},
timeout=30.0,
)

View File

@ -521,6 +521,7 @@ def test_get_oauth2_config(app_context: None) -> None:
"token_request_uri": "https://abcd1234.snowflakecomputing.com/oauth/token-request", "token_request_uri": "https://abcd1234.snowflakecomputing.com/oauth/token-request",
"scope": "refresh_token session:role:USERADMIN", "scope": "refresh_token session:role:USERADMIN",
"redirect_uri": "http://example.com/api/v1/database/oauth2/", "redirect_uri": "http://example.com/api/v1/database/oauth2/",
"request_content_type": "json",
} }