From 68499a1199415340137105c8a7d649f10bf3c71e Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Tue, 26 Nov 2024 15:57:01 -0500 Subject: [PATCH] feat: purge OAuth2 tokens when DB changes (#31164) --- superset/commands/database/update.py | 32 ++++++++ superset/config.py | 12 ++- superset/models/core.py | 14 +++- .../commands/databases/update_test.py | 52 ++++++++++++ tests/unit_tests/models/core_test.py | 80 +++++++++++++++++++ 5 files changed, 186 insertions(+), 4 deletions(-) diff --git a/superset/commands/database/update.py b/superset/commands/database/update.py index 41dee7b79..85439afa7 100644 --- a/superset/commands/database/update.py +++ b/superset/commands/database/update.py @@ -78,6 +78,10 @@ class UpdateDatabaseCommand(BaseCommand): # since they're name based original_database_name = self._model.database_name + # Depending on the changes to the OAuth2 configuration we may need to purge + # existing personal tokens. + self._handle_oauth2() + database = DatabaseDAO.update(self._model, self._properties) database.set_sqlalchemy_uri(database.sqlalchemy_uri) ssh_tunnel = self._handle_ssh_tunnel(database) @@ -88,6 +92,34 @@ class UpdateDatabaseCommand(BaseCommand): return database + def _handle_oauth2(self) -> None: + """ + Handle changes in OAuth2. + """ + if not self._model: + return + + current_config = self._model.get_oauth2_config() + if not current_config: + return + + new_config = self._properties["encrypted_extra"].get("oauth2_client_info", {}) + + # Keys that require purging personal tokens because they probably are no longer + # valid. For example, if the scope has changed the existing tokens are still + # associated with the old scope. Similarly, if the endpoints changed the tokens + # are probably no longer valid. 
+ keys = { + "id", + "scope", + "authorization_request_uri", + "token_request_uri", + } + for key in keys: + if current_config.get(key) != new_config.get(key): + self._model.purge_oauth2_tokens() + break + def _handle_ssh_tunnel(self, database: Database) -> SSHTunnel | None: """ Delete, create, or update an SSH tunnel. diff --git a/superset/config.py b/superset/config.py index d278747bf..9e92bf79b 100644 --- a/superset/config.py +++ b/superset/config.py @@ -1543,9 +1543,12 @@ PREFERRED_DATABASES: list[str] = [ # one here. TEST_DATABASE_CONNECTION_TIMEOUT = timedelta(seconds=30) -# Details needed for databases that allows user to authenticate using personal -# OAuth2 tokens. See https://github.com/apache/superset/issues/20300 for more -# information. The scope and URIs are optional. +# Details needed for databases that allows user to authenticate using personal OAuth2 +# tokens. See https://github.com/apache/superset/issues/20300 for more information. The +# scope and URIs are usually optional. +# NOTE that if you change the id, scope, or URIs in this file, you probably need to purge +# the existing tokens from the database. This needs to be done by running a query to +# delete the existing tokens. DATABASE_OAUTH2_CLIENTS: dict[str, dict[str, Any]] = { # "Google Sheets": { # "id": "XXX.apps.googleusercontent.com", @@ -1561,14 +1564,17 @@ DATABASE_OAUTH2_CLIENTS: dict[str, dict[str, Any]] = { # "token_request_uri": "https://oauth2.googleapis.com/token", # }, } + # OAuth2 state is encoded in a JWT using the alogorithm below. DATABASE_OAUTH2_JWT_ALGORITHM = "HS256" + # By default the redirect URI points to /api/v1/database/oauth2/ and doesn't have to be # specified. If you're running multiple Superset instances you might want to have a # proxy handling the redirects, since redirect URIs need to be registered in the OAuth2 # applications. 
def purge_oauth2_tokens(self) -> None:
    """
    Delete all OAuth2 tokens associated with this database.

    This is needed when the configuration changes. For example, a new client ID and
    secret probably will require new tokens. The same is valid for changes in the
    scope or in the endpoints.

    The delete is issued through ``db.session``; no commit is performed here.
    """
    # Tokens point at their database through ``database_id``. The token's own
    # ``id`` is an unrelated primary key, so filtering on it would delete whichever
    # token happens to share a number with this database.
    db.session.query(DatabaseUserOAuth2Tokens).filter(
        DatabaseUserOAuth2Tokens.database_id == self.id
    ).delete()
def test_purge_oauth2_tokens(session: Session) -> None:
    """
    Test the `purge_oauth2_tokens` method.

    Tokens are inserted in the opposite order of their databases so that token
    primary keys do not coincide with database ids. Otherwise a purge that filters
    on the wrong column could still pass by accident.
    """
    from flask_appbuilder.security.sqla.models import Role, User  # noqa: F401

    from superset.models.core import Database, DatabaseUserOAuth2Tokens

    Database.metadata.create_all(session.get_bind())  # pylint: disable=no-member

    user = User(
        first_name="Alice",
        last_name="Doe",
        email="adoe@example.org",
        username="adoe",
    )
    session.add(user)
    session.flush()

    database1 = Database(database_name="my_oauth2_db", sqlalchemy_uri="sqlite://")
    database2 = Database(database_name="my_other_oauth2_db", sqlalchemy_uri="sqlite://")
    session.add_all([database1, database2])
    session.flush()

    # Insert the token of the *second* database first, flushing in between, so the
    # generated token ids are deliberately out of step with the database ids.
    other_token = DatabaseUserOAuth2Tokens(
        user_id=user.id,
        database_id=database2.id,
        access_token="my_other_access_token",
        access_token_expiration=datetime(2024, 1, 1),
        refresh_token="my_other_refresh_token",
    )
    session.add(other_token)
    session.flush()

    purged_token = DatabaseUserOAuth2Tokens(
        user_id=user.id,
        database_id=database1.id,
        access_token="my_access_token",
        access_token_expiration=datetime(2023, 1, 1),
        refresh_token="my_refresh_token",
    )
    session.add(purged_token)
    session.flush()

    assert session.query(DatabaseUserOAuth2Tokens).count() == 2

    token = (
        session.query(DatabaseUserOAuth2Tokens)
        .filter_by(database_id=database1.id)
        .one()
    )
    assert token.user_id == user.id
    assert token.database_id == database1.id
    assert token.access_token == "my_access_token"
    assert token.access_token_expiration == datetime(2023, 1, 1)
    assert token.refresh_token == "my_refresh_token"

    # precondition: the test is only meaningful if ids do not line up by accident
    assert token.id != database1.id

    database1.purge_oauth2_tokens()

    # confirm token was deleted
    token = (
        session.query(DatabaseUserOAuth2Tokens)
        .filter_by(database_id=database1.id)
        .one_or_none()
    )
    assert token is None

    # make sure other DB tokens weren't deleted
    remaining = session.query(DatabaseUserOAuth2Tokens).all()
    assert len(remaining) == 1
    assert remaining[0].database_id == database2.id
    assert remaining[0].access_token == "my_other_access_token"

    # make sure database was not deleted... just in case
    database = session.query(Database).filter_by(id=database1.id).one()
    assert database.database_name == "my_oauth2_db"