fix: make catalog migration lenient (#29549)

This commit is contained in:
Beto Dealmeida 2024-07-11 15:10:02 -04:00 committed by GitHub
parent 33b934cbb3
commit d535f3fe56
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 204 additions and 38 deletions

View File

@ -23,6 +23,7 @@ from typing import Any, Type
import sqlalchemy as sa
from alembic import op
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session
from superset import db, security_manager
from superset.daos.database import DatabaseDAO
@ -86,6 +87,24 @@ class Slice(Base):
schema_perm = sa.Column(sa.String(1000))
def get_schemas(database_name: str) -> list[str]:
    """
    Read all known schemas from the schema permissions.

    This is used as a fallback when the analytical database cannot be reached:
    instead of asking the database for its schemas, the schema names are
    recovered from the names of the view menus that carry a ``schema_access``
    permission.

    :param database_name: name of the database whose schema permissions are read
    :returns: sorted list of unique schema names found in the permissions
    """
    # The database name is passed as a bound parameter rather than being
    # interpolated into the SQL, so quotes in the name cannot break the query.
    query = sa.text(
        """
        SELECT
            avm.name
        FROM ab_view_menu avm
        JOIN ab_permission_view apv ON avm.id = apv.view_menu_id
        JOIN ab_permission ap ON apv.permission_id = ap.id
        WHERE
            avm.name LIKE :pattern AND
            ap.name = 'schema_access'
        """
    )

    # `op.execute()` is documented to return nothing, so the rows have to be
    # read through the connection the migration is bound to.
    rows = op.get_bind().execute(query, {"pattern": f"[{database_name}]%"})

    # LIKE treats `%` and `_` inside the database name as wildcards; re-check
    # the prefix in Python so only permissions of this exact database are kept.
    prefix = f"[{database_name}].["
    names = [
        row[0]
        for row in rows
        if row[0].startswith(prefix) and row[0].endswith("]")
    ]

    # [PostgreSQL].[postgres].[public] => public
    # Splitting on "].[" (instead of ".") keeps schema names that contain dots
    # intact.
    return sorted({name[1:-1].split("].[")[-1] for name in names})
def upgrade_catalog_perms(engines: set[str] | None = None) -> None:
"""
Update models when catalogs are introduced in a DB engine spec.
@ -116,25 +135,7 @@ def upgrade_catalog_perms(engines: set[str] | None = None) -> None:
)
add_pvms(session, {perm: ("catalog_access",)})
# update schema_perms
ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
for schema in database.get_all_schema_names(
catalog=catalog,
cache=False,
ssh_tunnel=ssh_tunnel,
):
perm = security_manager.get_schema_perm(
database.database_name,
None,
schema,
)
existing_pvm = session.query(ViewMenu).filter_by(name=perm).one_or_none()
if existing_pvm:
existing_pvm.name = security_manager.get_schema_perm(
database.database_name,
catalog,
schema,
)
upgrade_schema_perms(database, catalog, session)
# update existing models
models = [
@ -166,6 +167,35 @@ def upgrade_catalog_perms(engines: set[str] | None = None) -> None:
session.commit()
def upgrade_schema_perms(database: Database, catalog: str, session: Session) -> None:
    """
    Rename existing schema permissions to include the catalog.

    The list of schemas is read from the database itself; if that fails for any
    reason (for example, the database is offline) the failure is logged and the
    schemas are taken from the existing schema permissions instead, so the
    migration does not abort.

    :param database: database whose schema permissions are renamed
    :param catalog: catalog to add to each schema permission name
    :param session: session used to look up and update the view menus
    """
    # Imported locally so this function does not depend on a module-level logger.
    import logging  # pylint: disable=import-outside-toplevel

    ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
    try:
        schemas = database.get_all_schema_names(
            catalog=catalog,
            cache=False,
            ssh_tunnel=ssh_tunnel,
        )
    except Exception:  # pylint: disable=broad-except
        # Deliberately best-effort: any connection problem must not break the
        # migration, but it should be visible to whoever runs it.
        logging.getLogger(__name__).warning(
            "Unable to fetch schemas from database %s, falling back to the "
            "schemas found in existing permissions",
            database.database_name,
            exc_info=True,
        )
        schemas = get_schemas(database.database_name)

    for schema in schemas:
        # permission name without a catalog, as it exists before the upgrade
        perm = security_manager.get_schema_perm(
            database.database_name,
            None,
            schema,
        )
        existing_pvm = session.query(ViewMenu).filter_by(name=perm).one_or_none()
        if existing_pvm:
            # rename in place; the caller is responsible for committing
            existing_pvm.name = security_manager.get_schema_perm(
                database.database_name,
                catalog,
                schema,
            )
def downgrade_catalog_perms(engines: set[str] | None = None) -> None:
"""
Reverse the process of `upgrade_catalog_perms`.
@ -183,25 +213,7 @@ def downgrade_catalog_perms(engines: set[str] | None = None) -> None:
if catalog is None:
continue
# update schema_perms
ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
for schema in database.get_all_schema_names(
catalog=catalog,
cache=False,
ssh_tunnel=ssh_tunnel,
):
perm = security_manager.get_schema_perm(
database.database_name,
catalog,
schema,
)
existing_pvm = session.query(ViewMenu).filter_by(name=perm).one_or_none()
if existing_pvm:
existing_pvm.name = security_manager.get_schema_perm(
database.database_name,
None,
schema,
)
downgrade_schema_perms(database, catalog, session)
# update existing models
models = [
@ -231,3 +243,32 @@ def downgrade_catalog_perms(engines: set[str] | None = None) -> None:
chart.schema_perm = schema_perm
session.commit()
def downgrade_schema_perms(database: Database, catalog: str, session: Session) -> None:
    """
    Rename existing schema permissions to omit the catalog.

    The list of schemas is read from the database itself; if that fails for any
    reason (for example, the database is offline) the failure is logged and the
    schemas are taken from the existing schema permissions instead, so the
    migration does not abort.

    :param database: database whose schema permissions are renamed
    :param catalog: catalog to remove from each schema permission name
    :param session: session used to look up and update the view menus
    """
    # Imported locally so this function does not depend on a module-level logger.
    import logging  # pylint: disable=import-outside-toplevel

    ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
    try:
        schemas = database.get_all_schema_names(
            catalog=catalog,
            cache=False,
            ssh_tunnel=ssh_tunnel,
        )
    except Exception:  # pylint: disable=broad-except
        # Deliberately best-effort: any connection problem must not break the
        # migration, but it should be visible to whoever runs it.
        logging.getLogger(__name__).warning(
            "Unable to fetch schemas from database %s, falling back to the "
            "schemas found in existing permissions",
            database.database_name,
            exc_info=True,
        )
        schemas = get_schemas(database.database_name)

    for schema in schemas:
        # permission name including the catalog, as it exists after the upgrade
        perm = security_manager.get_schema_perm(
            database.database_name,
            catalog,
            schema,
        )
        existing_pvm = session.query(ViewMenu).filter_by(name=perm).one_or_none()
        if existing_pvm:
            # rename in place; the caller is responsible for committing
            existing_pvm.name = security_manager.get_schema_perm(
                database.database_name,
                None,
                schema,
            )

View File

@ -143,3 +143,128 @@ def test_upgrade_catalog_perms(mocker: MockerFixture, session: Session) -> None:
("[my_db].[public]",),
("[my_db].[db]",),
]
def test_upgrade_catalog_perms_graceful(
    mocker: MockerFixture,
    session: Session,
) -> None:
    """
    Test the `upgrade_catalog_perms` function when it fails to connect to the DB.

    During the migration we try to connect to the analytical database to get the list of
    schemas. This should fail gracefully and not raise an exception, since the database
    could be offline, and the permissions can be generated later when the admin enables
    catalog browsing on the database (permissions are always synced on a DB update, see
    `UpdateDatabaseCommand`).
    """
    from superset.connectors.sqla.models import SqlaTable
    from superset.models.core import Database
    from superset.models.slice import Slice
    from superset.models.sql_lab import Query, SavedQuery, TableSchema, TabState

    engine = session.get_bind()
    Database.metadata.create_all(engine)

    db = mocker.patch("superset.migrations.shared.catalogs.db")
    db.Session.return_value = session

    # simulate an analytical database that cannot be reached
    mocker.patch.object(
        Database,
        "get_all_schema_names",
        side_effect=Exception("Failed to connect to the database"),
    )
    # route the migration's `op` calls to the test session, so that the fallback
    # that reads schemas from existing permissions runs against the test data
    mocker.patch("superset.migrations.shared.catalogs.op", session)

    database = Database(
        database_name="my_db",
        sqlalchemy_uri="postgresql://localhost/db",
    )
    dataset = SqlaTable(
        table_name="my_table",
        database=database,
        catalog=None,
        schema="public",
        schema_perm="[my_db].[public]",
    )
    session.add(dataset)
    session.commit()

    chart = Slice(
        slice_name="my_chart",
        datasource_type="table",
        datasource_id=dataset.id,
    )
    query = Query(
        client_id="foo",
        database=database,
        catalog=None,
        schema="public",
    )
    saved_query = SavedQuery(
        database=database,
        sql="SELECT * FROM public.t",
        catalog=None,
        schema="public",
    )
    tab_state = TabState(
        database=database,
        catalog=None,
        schema="public",
    )
    table_schema = TableSchema(
        database=database,
        catalog=None,
        schema="public",
    )
    session.add_all([chart, query, saved_query, tab_state, table_schema])
    session.commit()

    # before migration
    assert dataset.catalog is None
    assert query.catalog is None
    assert saved_query.catalog is None
    assert tab_state.catalog is None
    assert table_schema.catalog is None
    assert dataset.schema_perm == "[my_db].[public]"
    assert chart.schema_perm == "[my_db].[public]"
    assert session.query(ViewMenu.name).all() == [
        ("[my_db].(id:1)",),
        ("[my_db].[my_table](id:1)",),
        ("[my_db].[public]",),
    ]

    upgrade_catalog_perms()

    # after migration
    assert dataset.catalog == "db"
    assert query.catalog == "db"
    assert saved_query.catalog == "db"
    assert tab_state.catalog == "db"
    assert table_schema.catalog == "db"
    assert dataset.schema_perm == "[my_db].[db].[public]"
    assert chart.schema_perm == "[my_db].[db].[public]"
    assert session.query(ViewMenu.name).all() == [
        ("[my_db].(id:1)",),
        ("[my_db].[my_table](id:1)",),
        ("[my_db].[db].[public]",),
        ("[my_db].[db]",),
    ]

    downgrade_catalog_perms()

    # revert
    assert dataset.catalog is None
    assert query.catalog is None
    assert saved_query.catalog is None
    assert tab_state.catalog is None
    assert table_schema.catalog is None
    assert dataset.schema_perm == "[my_db].[public]"
    assert chart.schema_perm == "[my_db].[public]"
    # the catalog permission created by the upgrade is still present here
    assert session.query(ViewMenu.name).all() == [
        ("[my_db].(id:1)",),
        ("[my_db].[my_table](id:1)",),
        ("[my_db].[public]",),
        ("[my_db].[db]",),
    ]