From 305b6df6e3e5aaa6d3faa8fa1a2d91fcb05b7c34 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Ferr=C3=A3o?= Date: Mon, 4 Nov 2024 17:54:47 +0100 Subject: [PATCH] feat(oauth2): add support for trino (#30081) --- superset/db_engine_specs/base.py | 46 +++---- superset/db_engine_specs/trino.py | 27 +++- superset/superset_typing.py | 4 + superset/utils/oauth2.py | 7 +- .../db_engine_specs/test_gsheets.py | 1 + .../unit_tests/db_engine_specs/test_trino.py | 127 +++++++++++++----- tests/unit_tests/models/core_test.py | 1 + 7 files changed, 156 insertions(+), 57 deletions(-) diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index a086f6eff..8cabb1e58 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -433,6 +433,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods oauth2_scope = "" oauth2_authorization_request_uri: str | None = None # pylint: disable=invalid-name oauth2_token_request_uri: str | None = None + oauth2_token_request_type = "data" # Driver-specific exception that should be mapped to OAuth2RedirectError oauth2_exception = OAuth2RedirectError @@ -525,6 +526,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods "token_request_uri", cls.oauth2_token_request_uri, ), + "request_content_type": db_engine_spec_config.get( + "request_content_type", cls.oauth2_token_request_type + ), } return config @@ -562,18 +566,16 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods """ timeout = current_app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds() uri = config["token_request_uri"] - response = requests.post( - uri, - json={ - "code": code, - "client_id": config["id"], - "client_secret": config["secret"], - "redirect_uri": config["redirect_uri"], - "grant_type": "authorization_code", - }, - timeout=timeout, - ) - return response.json() + req_body = { + "code": code, + "client_id": config["id"], + "client_secret": config["secret"], + "redirect_uri": config["redirect_uri"], + "grant_type": "authorization_code", + } + if config["request_content_type"] == "data": + return requests.post(uri, data=req_body, timeout=timeout).json() + return requests.post(uri, json=req_body, timeout=timeout).json() @classmethod def get_oauth2_fresh_token( @@ -586,17 +588,15 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods """ timeout = current_app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds() uri = config["token_request_uri"] - response = requests.post( - uri, - json={ - "client_id": config["id"], - "client_secret": config["secret"], - "refresh_token": refresh_token, - "grant_type": "refresh_token", - }, - timeout=timeout, - ) - return response.json() + req_body = { + "client_id": config["id"], + "client_secret": config["secret"], + "refresh_token": refresh_token, + "grant_type": "refresh_token", + } + if config["request_content_type"] == "data": + return requests.post(uri, data=req_body, timeout=timeout).json() + return requests.post(uri, json=req_body, timeout=timeout).json() @classmethod def get_allows_alias_in_select( diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py index c47352821..ad00557f6 100644 --- a/superset/db_engine_specs/trino.py +++ b/superset/db_engine_specs/trino.py @@ -27,11 +27,13 @@ from typing import Any, TYPE_CHECKING import numpy as np import pandas as pd import pyarrow as pa -from flask import ctx, current_app, Flask, g +import requests +from flask import copy_current_request_context, ctx, current_app, Flask, g from sqlalchemy import text from sqlalchemy.engine.reflection import Inspector from sqlalchemy.engine.url import URL from sqlalchemy.exc import NoSuchTableError +from trino.exceptions import HttpError from superset import db from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY, USER_AGENT @@ -60,11 +62,28 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +class CustomTrinoAuthErrorMeta(type): + def __instancecheck__(cls, instance: object) -> bool: + logger.info("is this being called?") + return isinstance( + instance, HttpError + ) and "error 401: b'Invalid credentials'" in str(instance) + + +class TrinoAuthError(HttpError, metaclass=CustomTrinoAuthErrorMeta): + pass + + class TrinoEngineSpec(PrestoBaseEngineSpec): engine = "trino" engine_name = "Trino" allows_alias_to_source_column = False + # OAuth 2.0 + supports_oauth2 = True + oauth2_exception = TrinoAuthError + oauth2_token_request_type = "data" + @classmethod def get_extra_table_metadata( cls, @@ -142,6 +161,10 @@ class TrinoEngineSpec(PrestoBaseEngineSpec): # Set principal_username=$effective_username if backend_name == "trino" and username is not None: connect_args["user"] = username + if access_token is not None: + http_session = requests.Session() + http_session.headers.update({"Authorization": f"Bearer {access_token}"}) + connect_args["http_session"] = http_session @classmethod def get_url_for_impersonation( @@ -154,6 +177,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec): """ Return a modified URL with the username set. + :param access_token: Personal access token for OAuth2 :param url: SQLAlchemy URL object :param impersonate_user: Flag indicating if impersonation is enabled :param username: Effective username @@ -228,6 +252,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec): execute_result: dict[str, Any] = {} execute_event = threading.Event() + @copy_current_request_context def _execute( results: dict[str, Any], event: threading.Event, diff --git a/superset/superset_typing.py b/superset/superset_typing.py index 3a850e0ac..c3c40cd31 100644 --- a/superset/superset_typing.py +++ b/superset/superset_typing.py @@ -149,6 +149,10 @@ class OAuth2ClientConfig(TypedDict): # expired access token. token_request_uri: str + # Not all identity providers expect json. Keycloak expects a form encoded request, + # which in the `requests` package context means using the `data` param, not `json`. + request_content_type: str + class OAuth2TokenResponse(TypedDict, total=False): """ diff --git a/superset/utils/oauth2.py b/superset/utils/oauth2.py index b889ef83c..95db2921f 100644 --- a/superset/utils/oauth2.py +++ b/superset/utils/oauth2.py @@ -23,7 +23,7 @@ from typing import Any, TYPE_CHECKING import backoff import jwt from flask import current_app, url_for -from marshmallow import EXCLUDE, fields, post_load, Schema +from marshmallow import EXCLUDE, fields, post_load, Schema, validate from superset import db from superset.distributed_lock import KeyValueDistributedLock @@ -192,3 +192,8 @@ class OAuth2ClientConfigSchema(Schema): ) authorization_request_uri = fields.String(required=True) token_request_uri = fields.String(required=True) + request_content_type = fields.String( + required=False, + load_default=lambda: "json", + validate=validate.OneOf(["json", "data"]), + ) diff --git a/tests/unit_tests/db_engine_specs/test_gsheets.py b/tests/unit_tests/db_engine_specs/test_gsheets.py index 5d2ddb807..4e17054db 100644 --- a/tests/unit_tests/db_engine_specs/test_gsheets.py +++ b/tests/unit_tests/db_engine_specs/test_gsheets.py @@ -559,6 +559,7 @@ def oauth2_config() -> OAuth2ClientConfig: "redirect_uri": "http://localhost:8088/api/v1/oauth2/", "authorization_request_uri": "https://accounts.google.com/o/oauth2/v2/auth", "token_request_uri": "https://oauth2.googleapis.com/token", + "request_content_type": "json", } diff --git a/tests/unit_tests/db_engine_specs/test_trino.py b/tests/unit_tests/db_engine_specs/test_trino.py index 5a32cd050..b616adfcf 100644 --- a/tests/unit_tests/db_engine_specs/test_trino.py +++ b/tests/unit_tests/db_engine_specs/test_trino.py @@ -45,7 +45,12 @@ from superset.db_engine_specs.exceptions import ( SupersetDBAPIProgrammingError, ) from superset.sql_parse import Table -from superset.superset_typing import ResultSetColumnType, SQLAColumnType, SQLType +from superset.superset_typing import ( + OAuth2ClientConfig, + ResultSetColumnType, + SQLAColumnType, + SQLType, +) from superset.utils import json from superset.utils.core import GenericDataType from tests.unit_tests.db_engine_specs.utils import ( @@ -421,38 +426,9 @@ def test_execute_with_cursor_in_parallel(app, mocker: MockerFixture): def _mock_execute(*args, **kwargs): mock_cursor.query_id = query_id - mock_cursor.execute.side_effect = _mock_execute - with patch.dict( - "superset.config.DISALLOWED_SQL_FUNCTIONS", - {}, - clear=True, - ): - TrinoEngineSpec.execute_with_cursor( - cursor=mock_cursor, - sql="SELECT 1 FROM foo", - query=mock_query, - ) + with app.test_request_context("/some/place/"): + mock_cursor.execute.side_effect = _mock_execute - mock_query.set_extra_json_key.assert_called_once_with( - key=QUERY_CANCEL_KEY, value=query_id - ) - - -def test_execute_with_cursor_app_context(app, mocker: MockerFixture): - """Test that `execute_with_cursor` still contains the current app context""" - from superset.db_engine_specs.trino import TrinoEngineSpec - - mock_cursor = mocker.MagicMock() - mock_cursor.query_id = None - - mock_query = mocker.MagicMock() - g.some_value = "some_value" - - def _mock_execute(*args, **kwargs): - assert has_app_context() - assert g.some_value == "some_value" - - with patch.object(TrinoEngineSpec, "execute", side_effect=_mock_execute): with patch.dict( "superset.config.DISALLOWED_SQL_FUNCTIONS", {}, @@ -464,6 +440,39 @@ def test_execute_with_cursor_app_context(app, mocker: MockerFixture): query=mock_query, ) + mock_query.set_extra_json_key.assert_called_once_with( + key=QUERY_CANCEL_KEY, value=query_id + ) + + +def test_execute_with_cursor_app_context(app, mocker: MockerFixture): + """Test that `execute_with_cursor` still contains the current app context""" + from superset.db_engine_specs.trino import TrinoEngineSpec + + mock_cursor = mocker.MagicMock() + mock_cursor.query_id = None + + mock_query = mocker.MagicMock() + + def _mock_execute(*args, **kwargs): + assert has_app_context() + assert g.some_value == "some_value" + + with app.test_request_context("/some/place/"): + g.some_value = "some_value" + + with patch.object(TrinoEngineSpec, "execute", side_effect=_mock_execute): + with patch.dict( + "superset.config.DISALLOWED_SQL_FUNCTIONS", + {}, + clear=True, + ): + TrinoEngineSpec.execute_with_cursor( + cursor=mock_cursor, + sql="SELECT 1 FROM foo", + query=mock_query, + ) + def test_get_columns(mocker: MockerFixture): """Test that ROW columns are not expanded without expand_rows""" @@ -784,3 +793,57 @@ def test_where_latest_partition( ) == f"""SELECT * FROM table \nWHERE partition_key = {expected_value}""" ) + + +@pytest.fixture +def oauth2_config() -> OAuth2ClientConfig: + """ + Config for Trino OAuth2. + """ + return { + "id": "trino", + "secret": "very-secret", + "scope": "", + "redirect_uri": "http://localhost:8088/api/v1/database/oauth2/", + "authorization_request_uri": "https://trino.auth.server.example/realms/master/protocol/openid-connect/auth", + "token_request_uri": "https://trino.auth.server.example/master/protocol/openid-connect/token", + "request_content_type": "data", + } + + +def test_get_oauth2_token( + mocker: MockerFixture, + oauth2_config: OAuth2ClientConfig, +) -> None: + """ + Test `get_oauth2_token`. + """ + from superset.db_engine_specs.trino import TrinoEngineSpec + + requests = mocker.patch("superset.db_engine_specs.base.requests") + requests.post().json.return_value = { + "access_token": "access-token", + "expires_in": 3600, + "scope": "scope", + "token_type": "Bearer", + "refresh_token": "refresh-token", + } + + assert TrinoEngineSpec.get_oauth2_token(oauth2_config, "code") == { + "access_token": "access-token", + "expires_in": 3600, + "scope": "scope", + "token_type": "Bearer", + "refresh_token": "refresh-token", + } + requests.post.assert_called_with( + "https://trino.auth.server.example/master/protocol/openid-connect/token", + data={ + "code": "code", + "client_id": "trino", + "client_secret": "very-secret", + "redirect_uri": "http://localhost:8088/api/v1/database/oauth2/", + "grant_type": "authorization_code", + }, + timeout=30.0, + ) diff --git a/tests/unit_tests/models/core_test.py b/tests/unit_tests/models/core_test.py index 452cbb6f5..1dff4784e 100644 --- a/tests/unit_tests/models/core_test.py +++ b/tests/unit_tests/models/core_test.py @@ -521,6 +521,7 @@ def test_get_oauth2_config(app_context: None) -> None: "token_request_uri": "https://abcd1234.snowflakecomputing.com/oauth/token-request", "scope": "refresh_token session:role:USERADMIN", "redirect_uri": "http://example.com/api/v1/database/oauth2/", + "request_content_type": "json", }