def get_schemas(database_name: str) -> list[str]:
    """
    Read all known schemas of a database from the existing schema permissions.

    Schema permissions are stored as view menus named ``[database].[schema]`` or
    ``[database].[catalog].[schema]``; the schema is always the last bracketed
    component of the name.

    :param database_name: name of the database whose schema permissions are read
    :returns: sorted list of unique schema names found in the permissions
    """
    query = sa.text(
        """
SELECT
    avm.name
FROM ab_view_menu avm
JOIN ab_permission_view apv ON avm.id = apv.view_menu_id
JOIN ab_permission ap ON apv.permission_id = ap.id
WHERE
    avm.name LIKE :pattern AND
    ap.name = 'schema_access'
        """
    )
    prefix = f"[{database_name}]."

    # `op.execute` does not return a result set, so the query has to be run on the
    # migration connection itself. The database name is passed as a bound parameter
    # instead of being interpolated into the SQL text.
    rows = op.get_bind().execute(query, {"pattern": f"{prefix}%"})

    # `LIKE` treats `_` and `%` in the database name as wildcards, so the prefix is
    # checked again in Python to drop permissions that belong to other databases.
    #
    # [PostgreSQL].[postgres].[public] => public
    # Splitting on "].[" (rather than ".") keeps schema names containing dots intact.
    return sorted(
        {
            row[0][1:-1].split("].[")[-1]
            for row in rows
            if row[0].startswith(prefix)
        }
    )
def upgrade_schema_perms(database: Database, catalog: str, session: Session) -> None:
    """
    Rename the schema permissions of a database so that they include the catalog.

    The list of schemas is fetched from the database itself; if that raises for any
    reason, the schemas are taken from the existing schema permissions instead.

    :param database: database whose schema permissions are renamed
    :param catalog: catalog to insert into the permission names
    :param session: session used to look up the existing view menus
    """
    tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
    try:
        schema_names = database.get_all_schema_names(
            catalog=catalog,
            cache=False,
            ssh_tunnel=tunnel,
        )
    except Exception:  # pylint: disable=broad-except
        # fall back to the schemas already recorded in the permissions
        schema_names = get_schemas(database.database_name)

    for schema_name in schema_names:
        # permission name without the catalog
        legacy_perm = security_manager.get_schema_perm(
            database.database_name,
            None,
            schema_name,
        )
        view_menu = session.query(ViewMenu).filter_by(name=legacy_perm).one_or_none()
        if not view_menu:
            continue

        # same permission, now qualified with the catalog
        view_menu.name = security_manager.get_schema_perm(
            database.database_name,
            catalog,
            schema_name,
        )
def downgrade_schema_perms(database: Database, catalog: str, session: Session) -> None:
    """
    Rename the schema permissions of a database so that they omit the catalog.

    The list of schemas is fetched from the database itself; if that raises for any
    reason, the schemas are taken from the existing schema permissions instead.

    :param database: database whose schema permissions are renamed
    :param catalog: catalog to strip from the permission names
    :param session: session used to look up the existing view menus
    """
    tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
    try:
        schema_names = database.get_all_schema_names(
            catalog=catalog,
            cache=False,
            ssh_tunnel=tunnel,
        )
    except Exception:  # pylint: disable=broad-except
        # fall back to the schemas already recorded in the permissions
        schema_names = get_schemas(database.database_name)

    for schema_name in schema_names:
        # permission name qualified with the catalog
        catalog_perm = security_manager.get_schema_perm(
            database.database_name,
            catalog,
            schema_name,
        )
        view_menu = session.query(ViewMenu).filter_by(name=catalog_perm).one_or_none()
        if not view_menu:
            continue

        # same permission, without the catalog
        view_menu.name = security_manager.get_schema_perm(
            database.database_name,
            None,
            schema_name,
        )
def test_upgrade_catalog_perms_graceful(
    mocker: MockerFixture,
    session: Session,
) -> None:
    """
    Test the `upgrade_catalog_perms` function when it fails to connect to the DB.

    During the migration we try to connect to the analytical database to get the list of
    schemas. This should fail gracefully and not raise an exception, since the database
    could be offline, and the permissions can be generated later when the admin enables
    catalog browsing on the database (permissions are always synced on a DB update, see
    `UpdateDatabaseCommand`).
    """
    from superset.connectors.sqla.models import SqlaTable
    from superset.models.core import Database
    from superset.models.slice import Slice
    from superset.models.sql_lab import Query, SavedQuery, TableSchema, TabState

    engine = session.get_bind()
    Database.metadata.create_all(engine)

    db = mocker.patch("superset.migrations.shared.catalogs.db")
    db.Session.return_value = session

    # simulate an analytical database that cannot be reached
    mocker.patch.object(
        Database,
        "get_all_schema_names",
        side_effect=Exception("Failed to connect to the database"),
    )
    # the fallback that reads schemas from the permission tables goes through `op`,
    # so point it at the test session
    mocker.patch("superset.migrations.shared.catalogs.op", session)

    database = Database(
        database_name="my_db",
        sqlalchemy_uri="postgresql://localhost/db",
    )
    dataset = SqlaTable(
        table_name="my_table",
        database=database,
        catalog=None,
        schema="public",
        schema_perm="[my_db].[public]",
    )
    session.add(dataset)
    session.commit()

    chart = Slice(
        slice_name="my_chart",
        datasource_type="table",
        datasource_id=dataset.id,
    )
    query = Query(
        client_id="foo",
        database=database,
        catalog=None,
        schema="public",
    )
    saved_query = SavedQuery(
        database=database,
        sql="SELECT * FROM public.t",
        catalog=None,
        schema="public",
    )
    tab_state = TabState(
        database=database,
        catalog=None,
        schema="public",
    )
    table_schema = TableSchema(
        database=database,
        catalog=None,
        schema="public",
    )
    session.add_all([chart, query, saved_query, tab_state, table_schema])
    session.commit()

    # before migration
    assert dataset.catalog is None
    assert query.catalog is None
    assert saved_query.catalog is None
    assert tab_state.catalog is None
    assert table_schema.catalog is None
    assert dataset.schema_perm == "[my_db].[public]"
    assert chart.schema_perm == "[my_db].[public]"
    assert session.query(ViewMenu.name).all() == [
        ("[my_db].(id:1)",),
        ("[my_db].[my_table](id:1)",),
        ("[my_db].[public]",),
    ]

    upgrade_catalog_perms()

    # after migration
    assert dataset.catalog == "db"
    assert query.catalog == "db"
    assert saved_query.catalog == "db"
    assert tab_state.catalog == "db"
    assert table_schema.catalog == "db"
    assert dataset.schema_perm == "[my_db].[db].[public]"
    assert chart.schema_perm == "[my_db].[db].[public]"
    assert session.query(ViewMenu.name).all() == [
        ("[my_db].(id:1)",),
        ("[my_db].[my_table](id:1)",),
        ("[my_db].[db].[public]",),
        ("[my_db].[db]",),
    ]

    downgrade_catalog_perms()

    # revert; the catalog permission created during the upgrade is still present
    assert dataset.catalog is None
    assert query.catalog is None
    assert saved_query.catalog is None
    assert tab_state.catalog is None
    assert table_schema.catalog is None
    assert dataset.schema_perm == "[my_db].[public]"
    assert chart.schema_perm == "[my_db].[public]"
    assert session.query(ViewMenu.name).all() == [
        ("[my_db].(id:1)",),
        ("[my_db].[my_table](id:1)",),
        ("[my_db].[public]",),
        ("[my_db].[db]",),
    ]