fix: move oauth2 capture to `get_sqla_engine` (#32137)

This commit is contained in:
Beto Dealmeida 2025-02-04 18:24:05 -05:00 committed by GitHub
parent c64018d421
commit c7c3b1b0e9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 123 additions and 24 deletions

View File

@ -84,7 +84,11 @@ from superset.superset_typing import (
from superset.utils import cache as cache_util, core as utils, json
from superset.utils.backports import StrEnum
from superset.utils.core import get_username
from superset.utils.oauth2 import get_oauth2_access_token, OAuth2ClientConfigSchema
from superset.utils.oauth2 import (
check_for_oauth2,
get_oauth2_access_token,
OAuth2ClientConfigSchema,
)
config = app.config
custom_password_store = config["SQLALCHEMY_CUSTOM_PASSWORD_STORE"]
@ -451,13 +455,14 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
engine_context_manager = config["ENGINE_CONTEXT_MANAGER"]
with engine_context_manager(self, catalog, schema):
yield self._get_sqla_engine(
catalog=catalog,
schema=schema,
nullpool=nullpool,
source=source,
sqlalchemy_uri=sqlalchemy_uri,
)
with check_for_oauth2(self):
yield self._get_sqla_engine(
catalog=catalog,
schema=schema,
nullpool=nullpool,
source=source,
sqlalchemy_uri=sqlalchemy_uri,
)
def _get_sqla_engine( # pylint: disable=too-many-locals # noqa: C901
self,
@ -583,10 +588,9 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
nullpool=nullpool,
source=source,
) as engine:
try:
with check_for_oauth2(self):
with closing(engine.raw_connection()) as conn:
# pre-session queries are used to set the selected schema and, in the # noqa: E501
# future, the selected catalog
# pre-session queries are used to set the selected catalog/schema
for prequery in self.db_engine_spec.get_prequeries(
database=self,
catalog=catalog,
@ -597,11 +601,6 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
yield conn
except Exception as ex:
if self.is_oauth2_enabled() and self.db_engine_spec.needs_oauth2(ex):
self.db_engine_spec.start_oauth2_dance(self)
raise
def get_default_catalog(self) -> str | None:
"""
Return the default configured catalog for the database.

View File

@ -17,8 +17,9 @@
from __future__ import annotations
from contextlib import contextmanager
from datetime import datetime, timedelta, timezone
from typing import Any, TYPE_CHECKING
from typing import Any, Iterator, TYPE_CHECKING
import backoff
import jwt
@ -32,7 +33,7 @@ from superset.superset_typing import OAuth2ClientConfig, OAuth2State
if TYPE_CHECKING:
from superset.db_engine_specs.base import BaseEngineSpec
from superset.models.core import DatabaseUserOAuth2Tokens
from superset.models.core import Database, DatabaseUserOAuth2Tokens
JWT_EXPIRATION = timedelta(minutes=5)
@ -197,3 +198,16 @@ class OAuth2ClientConfigSchema(Schema):
load_default=lambda: "json",
validate=validate.OneOf(["json", "data"]),
)
@contextmanager
def check_for_oauth2(database: Database) -> Iterator[None]:
"""
Run code and check if OAuth2 is needed.
"""
try:
yield
except Exception as ex:
if database.is_oauth2_enabled() and database.db_engine_spec.needs_oauth2(ex):
database.db_engine_spec.start_oauth2_dance(database)
raise

View File

@ -558,16 +558,47 @@ def test_get_oauth2_config(app_context: None) -> None:
}
def test_raw_connection_oauth(mocker: MockerFixture) -> None:
def test_raw_connection_oauth_engine(mocker: MockerFixture) -> None:
"""
Test that we can start OAuth2 from `raw_connection()` errors.
Some databases that use OAuth2 need to trigger the flow when the connection is
created, rather than when the query runs. This happens when the SQLAlchemy engine
URI cannot be built without the user personal token.
With OAuth2, some databases will raise an exception when the engine is first created
(eg, BigQuery). Others, like, Snowflake, when the connection is created. And
finally, GSheets will raise an exception when the query is executed.
This test verifies that the exception is captured and raised correctly so that the
frontend can trigger the OAuth2 dance.
This tests verifies that when calling `raw_connection()` the OAuth2 flow is
triggered when the engine is created.
"""
g = mocker.patch("superset.db_engine_specs.base.g")
g.user = mocker.MagicMock()
g.user.id = 42
database = Database(
id=1,
database_name="my_db",
sqlalchemy_uri="sqlite://",
encrypted_extra=json.dumps(oauth2_client_info),
)
database.db_engine_spec.oauth2_exception = OAuth2Error # type: ignore
_get_sqla_engine = mocker.patch.object(database, "_get_sqla_engine")
_get_sqla_engine.side_effect = OAuth2Error("OAuth2 required")
with pytest.raises(OAuth2RedirectError) as excinfo:
with database.get_raw_connection() as conn:
conn.cursor()
assert str(excinfo.value) == "You don't have permission to access the data."
def test_raw_connection_oauth_connection(mocker: MockerFixture) -> None:
"""
Test that we can start OAuth2 from `raw_connection()` errors.
With OAuth2, some databases will raise an exception when the engine is first created
(eg, BigQuery). Others, like, Snowflake, when the connection is created. And
finally, GSheets will raise an exception when the query is executed.
This tests verifies that when calling `raw_connection()` the OAuth2 flow is
triggered when the connection is created.
"""
g = mocker.patch("superset.db_engine_specs.base.g")
g.user = mocker.MagicMock()
@ -591,6 +622,40 @@ def test_raw_connection_oauth(mocker: MockerFixture) -> None:
assert str(excinfo.value) == "You don't have permission to access the data."
def test_raw_connection_oauth_execute(mocker: MockerFixture) -> None:
"""
Test that we can start OAuth2 from `raw_connection()` errors.
With OAuth2, some databases will raise an exception when the engine is first created
(eg, BigQuery). Others, like, Snowflake, when the connection is created. And
finally, GSheets will raise an exception when the query is executed.
This tests verifies that when calling `raw_connection()` the OAuth2 flow is
triggered when the connection is created.
"""
g = mocker.patch("superset.db_engine_specs.base.g")
g.user = mocker.MagicMock()
g.user.id = 42
database = Database(
id=1,
database_name="my_db",
sqlalchemy_uri="sqlite://",
encrypted_extra=json.dumps(oauth2_client_info),
)
database.db_engine_spec.oauth2_exception = OAuth2Error # type: ignore
get_sqla_engine = mocker.patch.object(database, "get_sqla_engine")
get_sqla_engine().__enter__().raw_connection().cursor().execute.side_effect = (
OAuth2Error("OAuth2 required")
)
with pytest.raises(OAuth2RedirectError) as excinfo: # noqa: PT012
with database.get_raw_connection() as conn:
cursor = conn.cursor()
cursor.execute("SELECT 1")
assert str(excinfo.value) == "You don't have permission to access the data."
def test_get_schema_access_for_file_upload() -> None:
"""
Test the `get_schema_access_for_file_upload` method.
@ -638,6 +703,27 @@ def test_engine_context_manager(mocker: MockerFixture) -> None:
)
def test_engine_oauth2(mocker: MockerFixture) -> None:
"""
Test that we handle OAuth2 when `create_engine` fails.
"""
database = Database(database_name="my_db", sqlalchemy_uri="trino://")
mocker.patch.object(database, "_get_sqla_engine", side_effect=Exception)
mocker.patch.object(database, "is_oauth2_enabled", return_value=True)
mocker.patch.object(database.db_engine_spec, "needs_oauth2", return_value=True)
start_oauth2_dance = mocker.patch.object(
database.db_engine_spec,
"start_oauth2_dance",
side_effect=OAuth2Error("OAuth2 required"),
)
with pytest.raises(OAuth2Error):
with database.get_sqla_engine("catalog", "schema"):
pass
start_oauth2_dance.assert_called_with(database)
def test_purge_oauth2_tokens(session: Session) -> None:
"""
Test the `purge_oauth2_tokens` method.