feat: purge OAuth2 tokens when DB changes (#31164)

This commit is contained in:
Beto Dealmeida 2024-11-26 15:57:01 -05:00 committed by GitHub
parent f077323e6f
commit 68499a1199
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 186 additions and 4 deletions

View File

@ -78,6 +78,10 @@ class UpdateDatabaseCommand(BaseCommand):
# since they're name based
original_database_name = self._model.database_name
# Depending on the changes to the OAuth2 configuration we may need to purge
# existing personal tokens.
self._handle_oauth2()
database = DatabaseDAO.update(self._model, self._properties)
database.set_sqlalchemy_uri(database.sqlalchemy_uri)
ssh_tunnel = self._handle_ssh_tunnel(database)
@ -88,6 +92,34 @@ class UpdateDatabaseCommand(BaseCommand):
return database
def _handle_oauth2(self) -> None:
"""
Handle changes in OAuth2.
"""
if not self._model:
return
current_config = self._model.get_oauth2_config()
if not current_config:
return
new_config = self._properties["encrypted_extra"].get("oauth2_client_info", {})
# Keys that require purging personal tokens because they probably are no longer
# valid. For example, if the scope has changed the existing tokens are still
# associated with the old scope. Similarly, if the endpoints changed the tokens
# are probably no longer valid.
keys = {
"id",
"scope",
"authorization_request_uri",
"token_request_uri",
}
for key in keys:
if current_config.get(key) != new_config.get(key):
self._model.purge_oauth2_tokens()
break
def _handle_ssh_tunnel(self, database: Database) -> SSHTunnel | None:
"""
Delete, create, or update an SSH tunnel.

View File

@ -1543,9 +1543,12 @@ PREFERRED_DATABASES: list[str] = [
# one here.
TEST_DATABASE_CONNECTION_TIMEOUT = timedelta(seconds=30)
# Details needed for databases that allows user to authenticate using personal
# OAuth2 tokens. See https://github.com/apache/superset/issues/20300 for more
# information. The scope and URIs are optional.
# Details needed for databases that allows user to authenticate using personal OAuth2
# tokens. See https://github.com/apache/superset/issues/20300 for more information. The
# scope and URIs are usually optional.
# NOTE that if you change the id, scope, or URIs in this file, you probably need to purge
# the existing tokens from the database. This needs to be done by running a query to
# delete the existing tokens.
DATABASE_OAUTH2_CLIENTS: dict[str, dict[str, Any]] = {
# "Google Sheets": {
# "id": "XXX.apps.googleusercontent.com",
@ -1561,14 +1564,17 @@ DATABASE_OAUTH2_CLIENTS: dict[str, dict[str, Any]] = {
# "token_request_uri": "https://oauth2.googleapis.com/token",
# },
}
# OAuth2 state is encoded in a JWT using the alogorithm below.
DATABASE_OAUTH2_JWT_ALGORITHM = "HS256"
# By default the redirect URI points to /api/v1/database/oauth2/ and doesn't have to be
# specified. If you're running multiple Superset instances you might want to have a
# proxy handling the redirects, since redirect URIs need to be registered in the OAuth2
# applications. In that case, the proxy can forward the request to the correct instance
# by looking at the `default_redirect_uri` attribute in the OAuth2 state object.
# DATABASE_OAUTH2_REDIRECT_URI = "http://localhost:8088/api/v1/database/oauth2/"
# Timeout when fetching access and refresh tokens.
DATABASE_OAUTH2_TIMEOUT = timedelta(seconds=30)

View File

@ -60,7 +60,7 @@ from sqlalchemy.pool import NullPool
from sqlalchemy.schema import UniqueConstraint
from sqlalchemy.sql import ColumnElement, expression, Select
from superset import app, db_engine_specs, is_feature_enabled
from superset import app, db, db_engine_specs, is_feature_enabled
from superset.commands.database.exceptions import DatabaseInvalidError
from superset.constants import LRU_CACHE_MAX_SIZE, PASSWORD_MASK
from superset.databases.utils import make_url_safe
@ -1136,6 +1136,18 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
"""
return self.db_engine_spec.start_oauth2_dance(self)
def purge_oauth2_tokens(self) -> None:
"""
Delete all OAuth2 tokens associated with this database.
This is needed when the configuration changes. For example, a new client ID and
secret probably will require new tokens. The same is valid for changes in the
scope or in the endpoints.
"""
db.session.query(DatabaseUserOAuth2Tokens).filter(
DatabaseUserOAuth2Tokens.id == self.id
).delete()
sqla.event.listen(Database, "after_insert", security_manager.database_after_insert)
sqla.event.listen(Database, "after_update", security_manager.database_after_update)

View File

@ -24,6 +24,16 @@ from superset.commands.database.update import UpdateDatabaseCommand
from superset.exceptions import OAuth2RedirectError
from superset.extensions import security_manager
oauth2_client_info = {
"id": "client_id",
"secret": "client_secret",
"scope": "scope-a",
"redirect_uri": "redirect_uri",
"authorization_request_uri": "auth_uri",
"token_request_uri": "token_uri",
"request_content_type": "json",
}
@pytest.fixture()
def database_with_catalog(mocker: MockerFixture) -> MagicMock:
@ -72,6 +82,7 @@ def database_needs_oauth2(mocker: MockerFixture) -> MagicMock:
"tab_id",
"redirect_uri",
)
database.get_oauth2_config.return_value = oauth2_client_info
return database
@ -321,6 +332,47 @@ def test_update_with_oauth2(
"add_permission_view_menu",
)
database_needs_oauth2.db_engine_spec.unmask_encrypted_extra.return_value = {
"oauth2_client_info": oauth2_client_info,
}
UpdateDatabaseCommand(1, {}).run()
add_permission_view_menu.assert_not_called()
database_needs_oauth2.purge_oauth2_tokens.assert_not_called()
def test_update_with_oauth2_changed(
mocker: MockerFixture,
database_needs_oauth2: MockerFixture,
) -> None:
"""
Test that the database can be updated even if OAuth2 is needed to connect.
"""
DatabaseDAO = mocker.patch("superset.commands.database.update.DatabaseDAO")
DatabaseDAO.find_by_id.return_value = database_needs_oauth2
DatabaseDAO.update.return_value = database_needs_oauth2
find_permission_view_menu = mocker.patch.object(
security_manager,
"find_permission_view_menu",
)
find_permission_view_menu.side_effect = [
None, # schema1 has no permissions
"[my_db].[schema2]", # second schema already exists
]
add_permission_view_menu = mocker.patch.object(
security_manager,
"add_permission_view_menu",
)
modified_oauth2_client_info = oauth2_client_info.copy()
modified_oauth2_client_info["scope"] = "scope-b"
database_needs_oauth2.db_engine_spec.unmask_encrypted_extra.return_value = {
"oauth2_client_info": modified_oauth2_client_info,
}
UpdateDatabaseCommand(1, {}).run()
add_permission_view_menu.assert_not_called()
database_needs_oauth2.purge_oauth2_tokens.assert_called()

View File

@ -23,6 +23,7 @@ import pytest
from pytest_mock import MockerFixture
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import make_url
from sqlalchemy.orm.session import Session
from superset.connectors.sqla.models import SqlaTable, TableColumn
from superset.errors import SupersetErrorType
@ -603,3 +604,82 @@ def test_engine_context_manager(mocker: MockerFixture) -> None:
source=None,
sqlalchemy_uri="trino://",
)
def test_purge_oauth2_tokens(session: Session) -> None:
"""
Test the `purge_oauth2_tokens` method.
"""
from flask_appbuilder.security.sqla.models import Role, User # noqa: F401
from superset.models.core import Database, DatabaseUserOAuth2Tokens
Database.metadata.create_all(session.get_bind()) # pylint: disable=no-member
user = User(
first_name="Alice",
last_name="Doe",
email="adoe@example.org",
username="adoe",
)
session.add(user)
session.flush()
database1 = Database(database_name="my_oauth2_db", sqlalchemy_uri="sqlite://")
database2 = Database(database_name="my_other_oauth2_db", sqlalchemy_uri="sqlite://")
session.add_all([database1, database2])
session.flush()
tokens = [
DatabaseUserOAuth2Tokens(
user_id=user.id,
database_id=database1.id,
access_token="my_access_token",
access_token_expiration=datetime(2023, 1, 1),
refresh_token="my_refresh_token",
),
DatabaseUserOAuth2Tokens(
user_id=user.id,
database_id=database2.id,
access_token="my_other_access_token",
access_token_expiration=datetime(2024, 1, 1),
refresh_token="my_other_refresh_token",
),
]
session.add_all(tokens)
session.flush()
assert len(session.query(DatabaseUserOAuth2Tokens).all()) == 2
token = (
session.query(DatabaseUserOAuth2Tokens)
.filter_by(database_id=database1.id)
.one()
)
assert token.user_id == user.id
assert token.database_id == database1.id
assert token.access_token == "my_access_token"
assert token.access_token_expiration == datetime(2023, 1, 1)
assert token.refresh_token == "my_refresh_token"
database1.purge_oauth2_tokens()
# confirm token was deleted
token = (
session.query(DatabaseUserOAuth2Tokens)
.filter_by(database_id=database1.id)
.one_or_none()
)
assert token is None
# make sure other DB tokens weren't deleted
token = (
session.query(DatabaseUserOAuth2Tokens)
.filter_by(database_id=database2.id)
.one()
)
assert token is not None
# make sure database was not deleted... just in case
database = session.query(Database).filter_by(id=database1.id).one()
assert database.name == "my_oauth2_db"