From 41ed37ab02b79ef00b99c87d5678d52eeab0bed6 Mon Sep 17 00:00:00 2001 From: Vitor Avila <96086495+Vitor-Avila@users.noreply.github.com> Date: Fri, 10 Jan 2025 15:54:53 -0300 Subject: [PATCH] fix(oauth): Handle updates to the OAuth config (#31777) --- superset/commands/database/update.py | 29 +++--- .../commands/databases/update_test.py | 89 +++++++++++++++++-- 2 files changed, 98 insertions(+), 20 deletions(-) diff --git a/superset/commands/database/update.py b/superset/commands/database/update.py index 85439afa7..8cc5e4245 100644 --- a/superset/commands/database/update.py +++ b/superset/commands/database/update.py @@ -44,6 +44,7 @@ from superset.databases.ssh_tunnel.models import SSHTunnel from superset.db_engine_specs.base import GenericDBException from superset.exceptions import OAuth2RedirectError from superset.models.core import Database +from superset.utils import json from superset.utils.decorators import on_error, transaction logger = logging.getLogger(__name__) @@ -66,22 +67,23 @@ class UpdateDatabaseCommand(BaseCommand): self.validate() - # unmask ``encrypted_extra`` - self._properties["encrypted_extra"] = ( - self._model.db_engine_spec.unmask_encrypted_extra( - self._model.encrypted_extra, - self._properties.pop("masked_encrypted_extra", "{}"), + if "masked_encrypted_extra" in self._properties: + # unmask ``encrypted_extra`` + self._properties["encrypted_extra"] = ( + self._model.db_engine_spec.unmask_encrypted_extra( + self._model.encrypted_extra, + self._properties["masked_encrypted_extra"], + ) ) - ) + + # Depending on the changes to the OAuth2 configuration we may need to purge + # existing personal tokens. + self._handle_oauth2() # if the database name changed we need to update any existing permissions, # since they're name based original_database_name = self._model.database_name - # Depending on the changes to the OAuth2 configuration we may need to purge - # existing personal tokens. - self._handle_oauth2() - database = DatabaseDAO.update(self._model, self._properties) database.set_sqlalchemy_uri(database.sqlalchemy_uri) ssh_tunnel = self._handle_ssh_tunnel(database) @@ -99,11 +101,16 @@ class UpdateDatabaseCommand(BaseCommand): if not self._model: return + if self._properties["encrypted_extra"] is None: + self._model.purge_oauth2_tokens() + return + current_config = self._model.get_oauth2_config() if not current_config: return - new_config = self._properties["encrypted_extra"].get("oauth2_client_info", {}) + encrypted_extra = json.loads(self._properties["encrypted_extra"]) + new_config = encrypted_extra.get("oauth2_client_info", {}) # Keys that require purging personal tokens because they probably are no longer # valid. For example, if the scope has changed the existing tokens are still diff --git a/tests/unit_tests/commands/databases/update_test.py b/tests/unit_tests/commands/databases/update_test.py index 7ca3d70dc..a7a1dd97c 100644 --- a/tests/unit_tests/commands/databases/update_test.py +++ b/tests/unit_tests/commands/databases/update_test.py @@ -21,8 +21,10 @@ import pytest from pytest_mock import MockerFixture from superset.commands.database.update import UpdateDatabaseCommand +from superset.db_engine_specs.base import BaseEngineSpec from superset.exceptions import OAuth2RedirectError from superset.extensions import security_manager +from superset.utils import json oauth2_client_info = { "id": "client_id", @@ -82,7 +84,10 @@ def database_needs_oauth2(mocker: MockerFixture) -> MagicMock: "tab_id", "redirect_uri", ) - database.get_oauth2_config.return_value = oauth2_client_info + database.encrypted_extra = json.dumps({"oauth2_client_info": oauth2_client_info}) + database.db_engine_spec.unmask_encrypted_extra = ( + BaseEngineSpec.unmask_encrypted_extra + ) return database @@ -332,10 +337,6 @@ def test_update_with_oauth2( "add_permission_view_menu", ) - database_needs_oauth2.db_engine_spec.unmask_encrypted_extra.return_value = { - "oauth2_client_info": oauth2_client_info, - } - UpdateDatabaseCommand(1, {}).run() add_permission_view_menu.assert_not_called() @@ -368,11 +369,81 @@ def test_update_with_oauth2_changed( modified_oauth2_client_info = oauth2_client_info.copy() modified_oauth2_client_info["scope"] = "scope-b" - database_needs_oauth2.db_engine_spec.unmask_encrypted_extra.return_value = { - "oauth2_client_info": modified_oauth2_client_info, - } - UpdateDatabaseCommand(1, {}).run() + UpdateDatabaseCommand( + 1, + { + "masked_encrypted_extra": json.dumps( + {"oauth2_client_info": modified_oauth2_client_info} + ) + }, + ).run() add_permission_view_menu.assert_not_called() database_needs_oauth2.purge_oauth2_tokens.assert_called() + + +def test_remove_oauth_config_purges_tokens( + mocker: MockerFixture, + database_needs_oauth2: MockerFixture, +) -> None: + """ + Test that removing the OAuth config from a database purges existing tokens. + """ + DatabaseDAO = mocker.patch("superset.commands.database.update.DatabaseDAO") # noqa: N806 + DatabaseDAO.find_by_id.return_value = database_needs_oauth2 + DatabaseDAO.update.return_value = database_needs_oauth2 + + find_permission_view_menu = mocker.patch.object( + security_manager, + "find_permission_view_menu", + ) + find_permission_view_menu.side_effect = [ + None, + "[my_db].[schema2]", + ] + add_permission_view_menu = mocker.patch.object( + security_manager, + "add_permission_view_menu", + ) + + UpdateDatabaseCommand(1, {"masked_encrypted_extra": None}).run() + + add_permission_view_menu.assert_not_called() + database_needs_oauth2.purge_oauth2_tokens.assert_called() + + UpdateDatabaseCommand(1, {"masked_encrypted_extra": "{}"}).run() + + add_permission_view_menu.assert_not_called() + database_needs_oauth2.purge_oauth2_tokens.assert_called() + + +def test_update_other_fields_dont_affect_oauth( + mocker: MockerFixture, + database_needs_oauth2: MockerFixture, +) -> None: + """ + Test that not including ``masked_encrypted_extra`` in the payload does not + touch the OAuth config. + """ + DatabaseDAO = mocker.patch("superset.commands.database.update.DatabaseDAO") # noqa: N806 + DatabaseDAO.find_by_id.return_value = database_needs_oauth2 + DatabaseDAO.update.return_value = database_needs_oauth2 + + find_permission_view_menu = mocker.patch.object( + security_manager, + "find_permission_view_menu", + ) + find_permission_view_menu.side_effect = [ + None, + "[my_db].[schema2]", + ] + add_permission_view_menu = mocker.patch.object( + security_manager, + "add_permission_view_menu", + ) + + UpdateDatabaseCommand(1, {"database_name": "New DB name"}).run() + + add_permission_view_menu.assert_not_called() + database_needs_oauth2.purge_oauth2_tokens.assert_not_called()