fix: move oauth2 capture to `get_sqla_engine` (#32137)
This commit is contained in:
parent
c64018d421
commit
c7c3b1b0e9
|
|
@ -84,7 +84,11 @@ from superset.superset_typing import (
|
||||||
from superset.utils import cache as cache_util, core as utils, json
|
from superset.utils import cache as cache_util, core as utils, json
|
||||||
from superset.utils.backports import StrEnum
|
from superset.utils.backports import StrEnum
|
||||||
from superset.utils.core import get_username
|
from superset.utils.core import get_username
|
||||||
from superset.utils.oauth2 import get_oauth2_access_token, OAuth2ClientConfigSchema
|
from superset.utils.oauth2 import (
|
||||||
|
check_for_oauth2,
|
||||||
|
get_oauth2_access_token,
|
||||||
|
OAuth2ClientConfigSchema,
|
||||||
|
)
|
||||||
|
|
||||||
config = app.config
|
config = app.config
|
||||||
custom_password_store = config["SQLALCHEMY_CUSTOM_PASSWORD_STORE"]
|
custom_password_store = config["SQLALCHEMY_CUSTOM_PASSWORD_STORE"]
|
||||||
|
|
@ -451,13 +455,14 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
|
||||||
|
|
||||||
engine_context_manager = config["ENGINE_CONTEXT_MANAGER"]
|
engine_context_manager = config["ENGINE_CONTEXT_MANAGER"]
|
||||||
with engine_context_manager(self, catalog, schema):
|
with engine_context_manager(self, catalog, schema):
|
||||||
yield self._get_sqla_engine(
|
with check_for_oauth2(self):
|
||||||
catalog=catalog,
|
yield self._get_sqla_engine(
|
||||||
schema=schema,
|
catalog=catalog,
|
||||||
nullpool=nullpool,
|
schema=schema,
|
||||||
source=source,
|
nullpool=nullpool,
|
||||||
sqlalchemy_uri=sqlalchemy_uri,
|
source=source,
|
||||||
)
|
sqlalchemy_uri=sqlalchemy_uri,
|
||||||
|
)
|
||||||
|
|
||||||
def _get_sqla_engine( # pylint: disable=too-many-locals # noqa: C901
|
def _get_sqla_engine( # pylint: disable=too-many-locals # noqa: C901
|
||||||
self,
|
self,
|
||||||
|
|
@ -583,10 +588,9 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
|
||||||
nullpool=nullpool,
|
nullpool=nullpool,
|
||||||
source=source,
|
source=source,
|
||||||
) as engine:
|
) as engine:
|
||||||
try:
|
with check_for_oauth2(self):
|
||||||
with closing(engine.raw_connection()) as conn:
|
with closing(engine.raw_connection()) as conn:
|
||||||
# pre-session queries are used to set the selected schema and, in the # noqa: E501
|
# pre-session queries are used to set the selected catalog/schema
|
||||||
# future, the selected catalog
|
|
||||||
for prequery in self.db_engine_spec.get_prequeries(
|
for prequery in self.db_engine_spec.get_prequeries(
|
||||||
database=self,
|
database=self,
|
||||||
catalog=catalog,
|
catalog=catalog,
|
||||||
|
|
@ -597,11 +601,6 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
|
||||||
|
|
||||||
yield conn
|
yield conn
|
||||||
|
|
||||||
except Exception as ex:
|
|
||||||
if self.is_oauth2_enabled() and self.db_engine_spec.needs_oauth2(ex):
|
|
||||||
self.db_engine_spec.start_oauth2_dance(self)
|
|
||||||
raise
|
|
||||||
|
|
||||||
def get_default_catalog(self) -> str | None:
|
def get_default_catalog(self) -> str | None:
|
||||||
"""
|
"""
|
||||||
Return the default configured catalog for the database.
|
Return the default configured catalog for the database.
|
||||||
|
|
|
||||||
|
|
@ -17,8 +17,9 @@
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from contextlib import contextmanager
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import Any, TYPE_CHECKING
|
from typing import Any, Iterator, TYPE_CHECKING
|
||||||
|
|
||||||
import backoff
|
import backoff
|
||||||
import jwt
|
import jwt
|
||||||
|
|
@ -32,7 +33,7 @@ from superset.superset_typing import OAuth2ClientConfig, OAuth2State
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from superset.db_engine_specs.base import BaseEngineSpec
|
from superset.db_engine_specs.base import BaseEngineSpec
|
||||||
from superset.models.core import DatabaseUserOAuth2Tokens
|
from superset.models.core import Database, DatabaseUserOAuth2Tokens
|
||||||
|
|
||||||
JWT_EXPIRATION = timedelta(minutes=5)
|
JWT_EXPIRATION = timedelta(minutes=5)
|
||||||
|
|
||||||
|
|
@ -197,3 +198,16 @@ class OAuth2ClientConfigSchema(Schema):
|
||||||
load_default=lambda: "json",
|
load_default=lambda: "json",
|
||||||
validate=validate.OneOf(["json", "data"]),
|
validate=validate.OneOf(["json", "data"]),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def check_for_oauth2(database: Database) -> Iterator[None]:
|
||||||
|
"""
|
||||||
|
Run code and check if OAuth2 is needed.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
except Exception as ex:
|
||||||
|
if database.is_oauth2_enabled() and database.db_engine_spec.needs_oauth2(ex):
|
||||||
|
database.db_engine_spec.start_oauth2_dance(database)
|
||||||
|
raise
|
||||||
|
|
|
||||||
|
|
@ -558,16 +558,47 @@ def test_get_oauth2_config(app_context: None) -> None:
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def test_raw_connection_oauth(mocker: MockerFixture) -> None:
|
def test_raw_connection_oauth_engine(mocker: MockerFixture) -> None:
|
||||||
"""
|
"""
|
||||||
Test that we can start OAuth2 from `raw_connection()` errors.
|
Test that we can start OAuth2 from `raw_connection()` errors.
|
||||||
|
|
||||||
Some databases that use OAuth2 need to trigger the flow when the connection is
|
With OAuth2, some databases will raise an exception when the engine is first created
|
||||||
created, rather than when the query runs. This happens when the SQLAlchemy engine
|
(eg, BigQuery). Others, like, Snowflake, when the connection is created. And
|
||||||
URI cannot be built without the user personal token.
|
finally, GSheets will raise an exception when the query is executed.
|
||||||
|
|
||||||
This test verifies that the exception is captured and raised correctly so that the
|
This tests verifies that when calling `raw_connection()` the OAuth2 flow is
|
||||||
frontend can trigger the OAuth2 dance.
|
triggered when the engine is created.
|
||||||
|
"""
|
||||||
|
g = mocker.patch("superset.db_engine_specs.base.g")
|
||||||
|
g.user = mocker.MagicMock()
|
||||||
|
g.user.id = 42
|
||||||
|
|
||||||
|
database = Database(
|
||||||
|
id=1,
|
||||||
|
database_name="my_db",
|
||||||
|
sqlalchemy_uri="sqlite://",
|
||||||
|
encrypted_extra=json.dumps(oauth2_client_info),
|
||||||
|
)
|
||||||
|
database.db_engine_spec.oauth2_exception = OAuth2Error # type: ignore
|
||||||
|
_get_sqla_engine = mocker.patch.object(database, "_get_sqla_engine")
|
||||||
|
_get_sqla_engine.side_effect = OAuth2Error("OAuth2 required")
|
||||||
|
|
||||||
|
with pytest.raises(OAuth2RedirectError) as excinfo:
|
||||||
|
with database.get_raw_connection() as conn:
|
||||||
|
conn.cursor()
|
||||||
|
assert str(excinfo.value) == "You don't have permission to access the data."
|
||||||
|
|
||||||
|
|
||||||
|
def test_raw_connection_oauth_connection(mocker: MockerFixture) -> None:
|
||||||
|
"""
|
||||||
|
Test that we can start OAuth2 from `raw_connection()` errors.
|
||||||
|
|
||||||
|
With OAuth2, some databases will raise an exception when the engine is first created
|
||||||
|
(eg, BigQuery). Others, like, Snowflake, when the connection is created. And
|
||||||
|
finally, GSheets will raise an exception when the query is executed.
|
||||||
|
|
||||||
|
This tests verifies that when calling `raw_connection()` the OAuth2 flow is
|
||||||
|
triggered when the connection is created.
|
||||||
"""
|
"""
|
||||||
g = mocker.patch("superset.db_engine_specs.base.g")
|
g = mocker.patch("superset.db_engine_specs.base.g")
|
||||||
g.user = mocker.MagicMock()
|
g.user = mocker.MagicMock()
|
||||||
|
|
@ -591,6 +622,40 @@ def test_raw_connection_oauth(mocker: MockerFixture) -> None:
|
||||||
assert str(excinfo.value) == "You don't have permission to access the data."
|
assert str(excinfo.value) == "You don't have permission to access the data."
|
||||||
|
|
||||||
|
|
||||||
|
def test_raw_connection_oauth_execute(mocker: MockerFixture) -> None:
|
||||||
|
"""
|
||||||
|
Test that we can start OAuth2 from `raw_connection()` errors.
|
||||||
|
|
||||||
|
With OAuth2, some databases will raise an exception when the engine is first created
|
||||||
|
(eg, BigQuery). Others, like, Snowflake, when the connection is created. And
|
||||||
|
finally, GSheets will raise an exception when the query is executed.
|
||||||
|
|
||||||
|
This tests verifies that when calling `raw_connection()` the OAuth2 flow is
|
||||||
|
triggered when the connection is created.
|
||||||
|
"""
|
||||||
|
g = mocker.patch("superset.db_engine_specs.base.g")
|
||||||
|
g.user = mocker.MagicMock()
|
||||||
|
g.user.id = 42
|
||||||
|
|
||||||
|
database = Database(
|
||||||
|
id=1,
|
||||||
|
database_name="my_db",
|
||||||
|
sqlalchemy_uri="sqlite://",
|
||||||
|
encrypted_extra=json.dumps(oauth2_client_info),
|
||||||
|
)
|
||||||
|
database.db_engine_spec.oauth2_exception = OAuth2Error # type: ignore
|
||||||
|
get_sqla_engine = mocker.patch.object(database, "get_sqla_engine")
|
||||||
|
get_sqla_engine().__enter__().raw_connection().cursor().execute.side_effect = (
|
||||||
|
OAuth2Error("OAuth2 required")
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(OAuth2RedirectError) as excinfo: # noqa: PT012
|
||||||
|
with database.get_raw_connection() as conn:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute("SELECT 1")
|
||||||
|
assert str(excinfo.value) == "You don't have permission to access the data."
|
||||||
|
|
||||||
|
|
||||||
def test_get_schema_access_for_file_upload() -> None:
|
def test_get_schema_access_for_file_upload() -> None:
|
||||||
"""
|
"""
|
||||||
Test the `get_schema_access_for_file_upload` method.
|
Test the `get_schema_access_for_file_upload` method.
|
||||||
|
|
@ -638,6 +703,27 @@ def test_engine_context_manager(mocker: MockerFixture) -> None:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_engine_oauth2(mocker: MockerFixture) -> None:
|
||||||
|
"""
|
||||||
|
Test that we handle OAuth2 when `create_engine` fails.
|
||||||
|
"""
|
||||||
|
database = Database(database_name="my_db", sqlalchemy_uri="trino://")
|
||||||
|
mocker.patch.object(database, "_get_sqla_engine", side_effect=Exception)
|
||||||
|
mocker.patch.object(database, "is_oauth2_enabled", return_value=True)
|
||||||
|
mocker.patch.object(database.db_engine_spec, "needs_oauth2", return_value=True)
|
||||||
|
start_oauth2_dance = mocker.patch.object(
|
||||||
|
database.db_engine_spec,
|
||||||
|
"start_oauth2_dance",
|
||||||
|
side_effect=OAuth2Error("OAuth2 required"),
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(OAuth2Error):
|
||||||
|
with database.get_sqla_engine("catalog", "schema"):
|
||||||
|
pass
|
||||||
|
|
||||||
|
start_oauth2_dance.assert_called_with(database)
|
||||||
|
|
||||||
|
|
||||||
def test_purge_oauth2_tokens(session: Session) -> None:
|
def test_purge_oauth2_tokens(session: Session) -> None:
|
||||||
"""
|
"""
|
||||||
Test the `purge_oauth2_tokens` method.
|
Test the `purge_oauth2_tokens` method.
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue