From b66c0f8d30722a5d9062f322d60aa7c750b669d9 Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Tue, 16 Jul 2024 15:53:25 -0400 Subject: [PATCH] fix: schemas for upload API (#29604) --- superset/models/core.py | 4 ++-- .../integration_tests/dashboards/api_tests.py | 2 +- tests/unit_tests/models/core_test.py | 20 +++++++++++++++++++ 3 files changed, 23 insertions(+), 3 deletions(-) diff --git a/superset/models/core.py b/superset/models/core.py index 78bbf55cd..512c5a93e 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -982,7 +982,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable def get_schema_access_for_file_upload( # pylint: disable=invalid-name self, - ) -> list[str]: + ) -> set[str]: allowed_databases = self.get_extra().get("schemas_allowed_for_file_upload", []) if isinstance(allowed_databases, str): @@ -993,7 +993,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable self, g.user ) allowed_databases += extra_allowed_databases - return sorted(set(allowed_databases)) + return set(allowed_databases) @property def sqlalchemy_uri_decrypted(self) -> str: diff --git a/tests/integration_tests/dashboards/api_tests.py b/tests/integration_tests/dashboards/api_tests.py index b4e2958cc..259b9485f 100644 --- a/tests/integration_tests/dashboards/api_tests.py +++ b/tests/integration_tests/dashboards/api_tests.py @@ -2621,7 +2621,7 @@ class TestDashboardApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCas # Clean up system tags tag_list = [tag.id for tag in model.tags if tag.type == TagType.custom] - self.assertEqual(tag_list, new_tags) + self.assertEqual(sorted(tag_list), sorted(new_tags)) @pytest.mark.usefixtures("create_dashboard_with_tag") def test_update_dashboard_remove_tags_can_write_on_tag(self): diff --git a/tests/unit_tests/models/core_test.py b/tests/unit_tests/models/core_test.py index e563971a5..6f588cde2 100644 --- a/tests/unit_tests/models/core_test.py +++ b/tests/unit_tests/models/core_test.py @@ -461,3 +461,23 @@ def test_raw_connection_oauth(mocker: MockerFixture) -> None: with database.get_raw_connection() as conn: conn.cursor() assert str(excinfo.value) == "You don't have permission to access the data." + + +def test_get_schema_access_for_file_upload() -> None: + """ + Test the `get_schema_access_for_file_upload` method. + """ + database = Database( + database_name="first-database", + sqlalchemy_uri="gsheets://", + extra=json.dumps( + { + "metadata_params": {}, + "engine_params": {}, + "metadata_cache_timeout": {}, + "schemas_allowed_for_file_upload": '["public"]', + } + ), + ) + + assert database.get_schema_access_for_file_upload() == {"public"}