feat: purge OAuth2 tokens when DB changes (#31164)
This commit is contained in:
parent
f077323e6f
commit
68499a1199
|
|
@ -78,6 +78,10 @@ class UpdateDatabaseCommand(BaseCommand):
|
|||
# since they're name based
|
||||
original_database_name = self._model.database_name
|
||||
|
||||
# Depending on the changes to the OAuth2 configuration we may need to purge
|
||||
# existing personal tokens.
|
||||
self._handle_oauth2()
|
||||
|
||||
database = DatabaseDAO.update(self._model, self._properties)
|
||||
database.set_sqlalchemy_uri(database.sqlalchemy_uri)
|
||||
ssh_tunnel = self._handle_ssh_tunnel(database)
|
||||
|
|
@ -88,6 +92,34 @@ class UpdateDatabaseCommand(BaseCommand):
|
|||
|
||||
return database
|
||||
|
||||
def _handle_oauth2(self) -> None:
|
||||
"""
|
||||
Handle changes in OAuth2.
|
||||
"""
|
||||
if not self._model:
|
||||
return
|
||||
|
||||
current_config = self._model.get_oauth2_config()
|
||||
if not current_config:
|
||||
return
|
||||
|
||||
new_config = self._properties["encrypted_extra"].get("oauth2_client_info", {})
|
||||
|
||||
# Keys that require purging personal tokens because they probably are no longer
|
||||
# valid. For example, if the scope has changed the existing tokens are still
|
||||
# associated with the old scope. Similarly, if the endpoints changed the tokens
|
||||
# are probably no longer valid.
|
||||
keys = {
|
||||
"id",
|
||||
"scope",
|
||||
"authorization_request_uri",
|
||||
"token_request_uri",
|
||||
}
|
||||
for key in keys:
|
||||
if current_config.get(key) != new_config.get(key):
|
||||
self._model.purge_oauth2_tokens()
|
||||
break
|
||||
|
||||
def _handle_ssh_tunnel(self, database: Database) -> SSHTunnel | None:
|
||||
"""
|
||||
Delete, create, or update an SSH tunnel.
|
||||
|
|
|
|||
|
|
@ -1543,9 +1543,12 @@ PREFERRED_DATABASES: list[str] = [
|
|||
# one here.
|
||||
TEST_DATABASE_CONNECTION_TIMEOUT = timedelta(seconds=30)
|
||||
|
||||
# Details needed for databases that allows user to authenticate using personal
|
||||
# OAuth2 tokens. See https://github.com/apache/superset/issues/20300 for more
|
||||
# information. The scope and URIs are optional.
|
||||
# Details needed for databases that allows user to authenticate using personal OAuth2
|
||||
# tokens. See https://github.com/apache/superset/issues/20300 for more information. The
|
||||
# scope and URIs are usually optional.
|
||||
# NOTE that if you change the id, scope, or URIs in this file, you probably need to purge
|
||||
# the existing tokens from the database. This needs to be done by running a query to
|
||||
# delete the existing tokens.
|
||||
DATABASE_OAUTH2_CLIENTS: dict[str, dict[str, Any]] = {
|
||||
# "Google Sheets": {
|
||||
# "id": "XXX.apps.googleusercontent.com",
|
||||
|
|
@ -1561,14 +1564,17 @@ DATABASE_OAUTH2_CLIENTS: dict[str, dict[str, Any]] = {
|
|||
# "token_request_uri": "https://oauth2.googleapis.com/token",
|
||||
# },
|
||||
}
|
||||
|
||||
# OAuth2 state is encoded in a JWT using the alogorithm below.
|
||||
DATABASE_OAUTH2_JWT_ALGORITHM = "HS256"
|
||||
|
||||
# By default the redirect URI points to /api/v1/database/oauth2/ and doesn't have to be
|
||||
# specified. If you're running multiple Superset instances you might want to have a
|
||||
# proxy handling the redirects, since redirect URIs need to be registered in the OAuth2
|
||||
# applications. In that case, the proxy can forward the request to the correct instance
|
||||
# by looking at the `default_redirect_uri` attribute in the OAuth2 state object.
|
||||
# DATABASE_OAUTH2_REDIRECT_URI = "http://localhost:8088/api/v1/database/oauth2/"
|
||||
|
||||
# Timeout when fetching access and refresh tokens.
|
||||
DATABASE_OAUTH2_TIMEOUT = timedelta(seconds=30)
|
||||
|
||||
|
|
|
|||
|
|
@ -60,7 +60,7 @@ from sqlalchemy.pool import NullPool
|
|||
from sqlalchemy.schema import UniqueConstraint
|
||||
from sqlalchemy.sql import ColumnElement, expression, Select
|
||||
|
||||
from superset import app, db_engine_specs, is_feature_enabled
|
||||
from superset import app, db, db_engine_specs, is_feature_enabled
|
||||
from superset.commands.database.exceptions import DatabaseInvalidError
|
||||
from superset.constants import LRU_CACHE_MAX_SIZE, PASSWORD_MASK
|
||||
from superset.databases.utils import make_url_safe
|
||||
|
|
@ -1136,6 +1136,18 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
|
|||
"""
|
||||
return self.db_engine_spec.start_oauth2_dance(self)
|
||||
|
||||
def purge_oauth2_tokens(self) -> None:
|
||||
"""
|
||||
Delete all OAuth2 tokens associated with this database.
|
||||
|
||||
This is needed when the configuration changes. For example, a new client ID and
|
||||
secret probably will require new tokens. The same is valid for changes in the
|
||||
scope or in the endpoints.
|
||||
"""
|
||||
db.session.query(DatabaseUserOAuth2Tokens).filter(
|
||||
DatabaseUserOAuth2Tokens.id == self.id
|
||||
).delete()
|
||||
|
||||
|
||||
sqla.event.listen(Database, "after_insert", security_manager.database_after_insert)
|
||||
sqla.event.listen(Database, "after_update", security_manager.database_after_update)
|
||||
|
|
|
|||
|
|
@ -24,6 +24,16 @@ from superset.commands.database.update import UpdateDatabaseCommand
|
|||
from superset.exceptions import OAuth2RedirectError
|
||||
from superset.extensions import security_manager
|
||||
|
||||
oauth2_client_info = {
|
||||
"id": "client_id",
|
||||
"secret": "client_secret",
|
||||
"scope": "scope-a",
|
||||
"redirect_uri": "redirect_uri",
|
||||
"authorization_request_uri": "auth_uri",
|
||||
"token_request_uri": "token_uri",
|
||||
"request_content_type": "json",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def database_with_catalog(mocker: MockerFixture) -> MagicMock:
|
||||
|
|
@ -72,6 +82,7 @@ def database_needs_oauth2(mocker: MockerFixture) -> MagicMock:
|
|||
"tab_id",
|
||||
"redirect_uri",
|
||||
)
|
||||
database.get_oauth2_config.return_value = oauth2_client_info
|
||||
|
||||
return database
|
||||
|
||||
|
|
@ -321,6 +332,47 @@ def test_update_with_oauth2(
|
|||
"add_permission_view_menu",
|
||||
)
|
||||
|
||||
database_needs_oauth2.db_engine_spec.unmask_encrypted_extra.return_value = {
|
||||
"oauth2_client_info": oauth2_client_info,
|
||||
}
|
||||
|
||||
UpdateDatabaseCommand(1, {}).run()
|
||||
|
||||
add_permission_view_menu.assert_not_called()
|
||||
database_needs_oauth2.purge_oauth2_tokens.assert_not_called()
|
||||
|
||||
|
||||
def test_update_with_oauth2_changed(
|
||||
mocker: MockerFixture,
|
||||
database_needs_oauth2: MockerFixture,
|
||||
) -> None:
|
||||
"""
|
||||
Test that the database can be updated even if OAuth2 is needed to connect.
|
||||
"""
|
||||
DatabaseDAO = mocker.patch("superset.commands.database.update.DatabaseDAO")
|
||||
DatabaseDAO.find_by_id.return_value = database_needs_oauth2
|
||||
DatabaseDAO.update.return_value = database_needs_oauth2
|
||||
|
||||
find_permission_view_menu = mocker.patch.object(
|
||||
security_manager,
|
||||
"find_permission_view_menu",
|
||||
)
|
||||
find_permission_view_menu.side_effect = [
|
||||
None, # schema1 has no permissions
|
||||
"[my_db].[schema2]", # second schema already exists
|
||||
]
|
||||
add_permission_view_menu = mocker.patch.object(
|
||||
security_manager,
|
||||
"add_permission_view_menu",
|
||||
)
|
||||
|
||||
modified_oauth2_client_info = oauth2_client_info.copy()
|
||||
modified_oauth2_client_info["scope"] = "scope-b"
|
||||
database_needs_oauth2.db_engine_spec.unmask_encrypted_extra.return_value = {
|
||||
"oauth2_client_info": modified_oauth2_client_info,
|
||||
}
|
||||
|
||||
UpdateDatabaseCommand(1, {}).run()
|
||||
|
||||
add_permission_view_menu.assert_not_called()
|
||||
database_needs_oauth2.purge_oauth2_tokens.assert_called()
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ import pytest
|
|||
from pytest_mock import MockerFixture
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
from sqlalchemy.engine.url import make_url
|
||||
from sqlalchemy.orm.session import Session
|
||||
|
||||
from superset.connectors.sqla.models import SqlaTable, TableColumn
|
||||
from superset.errors import SupersetErrorType
|
||||
|
|
@ -603,3 +604,82 @@ def test_engine_context_manager(mocker: MockerFixture) -> None:
|
|||
source=None,
|
||||
sqlalchemy_uri="trino://",
|
||||
)
|
||||
|
||||
|
||||
def test_purge_oauth2_tokens(session: Session) -> None:
|
||||
"""
|
||||
Test the `purge_oauth2_tokens` method.
|
||||
"""
|
||||
from flask_appbuilder.security.sqla.models import Role, User # noqa: F401
|
||||
|
||||
from superset.models.core import Database, DatabaseUserOAuth2Tokens
|
||||
|
||||
Database.metadata.create_all(session.get_bind()) # pylint: disable=no-member
|
||||
|
||||
user = User(
|
||||
first_name="Alice",
|
||||
last_name="Doe",
|
||||
email="adoe@example.org",
|
||||
username="adoe",
|
||||
)
|
||||
session.add(user)
|
||||
session.flush()
|
||||
|
||||
database1 = Database(database_name="my_oauth2_db", sqlalchemy_uri="sqlite://")
|
||||
database2 = Database(database_name="my_other_oauth2_db", sqlalchemy_uri="sqlite://")
|
||||
session.add_all([database1, database2])
|
||||
session.flush()
|
||||
|
||||
tokens = [
|
||||
DatabaseUserOAuth2Tokens(
|
||||
user_id=user.id,
|
||||
database_id=database1.id,
|
||||
access_token="my_access_token",
|
||||
access_token_expiration=datetime(2023, 1, 1),
|
||||
refresh_token="my_refresh_token",
|
||||
),
|
||||
DatabaseUserOAuth2Tokens(
|
||||
user_id=user.id,
|
||||
database_id=database2.id,
|
||||
access_token="my_other_access_token",
|
||||
access_token_expiration=datetime(2024, 1, 1),
|
||||
refresh_token="my_other_refresh_token",
|
||||
),
|
||||
]
|
||||
session.add_all(tokens)
|
||||
session.flush()
|
||||
|
||||
assert len(session.query(DatabaseUserOAuth2Tokens).all()) == 2
|
||||
|
||||
token = (
|
||||
session.query(DatabaseUserOAuth2Tokens)
|
||||
.filter_by(database_id=database1.id)
|
||||
.one()
|
||||
)
|
||||
assert token.user_id == user.id
|
||||
assert token.database_id == database1.id
|
||||
assert token.access_token == "my_access_token"
|
||||
assert token.access_token_expiration == datetime(2023, 1, 1)
|
||||
assert token.refresh_token == "my_refresh_token"
|
||||
|
||||
database1.purge_oauth2_tokens()
|
||||
|
||||
# confirm token was deleted
|
||||
token = (
|
||||
session.query(DatabaseUserOAuth2Tokens)
|
||||
.filter_by(database_id=database1.id)
|
||||
.one_or_none()
|
||||
)
|
||||
assert token is None
|
||||
|
||||
# make sure other DB tokens weren't deleted
|
||||
token = (
|
||||
session.query(DatabaseUserOAuth2Tokens)
|
||||
.filter_by(database_id=database2.id)
|
||||
.one()
|
||||
)
|
||||
assert token is not None
|
||||
|
||||
# make sure database was not deleted... just in case
|
||||
database = session.query(Database).filter_by(id=database1.id).one()
|
||||
assert database.name == "my_oauth2_db"
|
||||
|
|
|
|||
Loading…
Reference in New Issue