fix: make catalog migration lenient (#29549)
This commit is contained in:
parent
33b934cbb3
commit
d535f3fe56
|
|
@ -23,6 +23,7 @@ from typing import Any, Type
|
|||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from superset import db, security_manager
|
||||
from superset.daos.database import DatabaseDAO
|
||||
|
|
@ -86,6 +87,24 @@ class Slice(Base):
|
|||
schema_perm = sa.Column(sa.String(1000))
|
||||
|
||||
|
||||
def get_schemas(database_name: str) -> list[str]:
|
||||
"""
|
||||
Read all known schemas from the schema permissions.
|
||||
"""
|
||||
query = f"""
|
||||
SELECT
|
||||
avm.name
|
||||
FROM ab_view_menu avm
|
||||
JOIN ab_permission_view apv ON avm.id = apv.view_menu_id
|
||||
JOIN ab_permission ap ON apv.permission_id = ap.id
|
||||
WHERE
|
||||
avm.name LIKE '[{database_name}]%' AND
|
||||
ap.name = 'schema_access';
|
||||
"""
|
||||
# [PostgreSQL].[postgres].[public] => public
|
||||
return sorted({row[0].split(".")[-1][1:-1] for row in op.execute(query)})
|
||||
|
||||
|
||||
def upgrade_catalog_perms(engines: set[str] | None = None) -> None:
|
||||
"""
|
||||
Update models when catalogs are introduced in a DB engine spec.
|
||||
|
|
@ -116,25 +135,7 @@ def upgrade_catalog_perms(engines: set[str] | None = None) -> None:
|
|||
)
|
||||
add_pvms(session, {perm: ("catalog_access",)})
|
||||
|
||||
# update schema_perms
|
||||
ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
|
||||
for schema in database.get_all_schema_names(
|
||||
catalog=catalog,
|
||||
cache=False,
|
||||
ssh_tunnel=ssh_tunnel,
|
||||
):
|
||||
perm = security_manager.get_schema_perm(
|
||||
database.database_name,
|
||||
None,
|
||||
schema,
|
||||
)
|
||||
existing_pvm = session.query(ViewMenu).filter_by(name=perm).one_or_none()
|
||||
if existing_pvm:
|
||||
existing_pvm.name = security_manager.get_schema_perm(
|
||||
database.database_name,
|
||||
catalog,
|
||||
schema,
|
||||
)
|
||||
upgrade_schema_perms(database, catalog, session)
|
||||
|
||||
# update existing models
|
||||
models = [
|
||||
|
|
@ -166,6 +167,35 @@ def upgrade_catalog_perms(engines: set[str] | None = None) -> None:
|
|||
session.commit()
|
||||
|
||||
|
||||
def upgrade_schema_perms(database: Database, catalog: str, session: Session) -> None:
|
||||
"""
|
||||
Rename existing schema permissions to include the catalog.
|
||||
"""
|
||||
ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
|
||||
try:
|
||||
schemas = database.get_all_schema_names(
|
||||
catalog=catalog,
|
||||
cache=False,
|
||||
ssh_tunnel=ssh_tunnel,
|
||||
)
|
||||
except Exception: # pylint: disable=broad-except
|
||||
schemas = get_schemas(database.database_name)
|
||||
|
||||
for schema in schemas:
|
||||
perm = security_manager.get_schema_perm(
|
||||
database.database_name,
|
||||
None,
|
||||
schema,
|
||||
)
|
||||
existing_pvm = session.query(ViewMenu).filter_by(name=perm).one_or_none()
|
||||
if existing_pvm:
|
||||
existing_pvm.name = security_manager.get_schema_perm(
|
||||
database.database_name,
|
||||
catalog,
|
||||
schema,
|
||||
)
|
||||
|
||||
|
||||
def downgrade_catalog_perms(engines: set[str] | None = None) -> None:
|
||||
"""
|
||||
Reverse the process of `upgrade_catalog_perms`.
|
||||
|
|
@ -183,25 +213,7 @@ def downgrade_catalog_perms(engines: set[str] | None = None) -> None:
|
|||
if catalog is None:
|
||||
continue
|
||||
|
||||
# update schema_perms
|
||||
ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
|
||||
for schema in database.get_all_schema_names(
|
||||
catalog=catalog,
|
||||
cache=False,
|
||||
ssh_tunnel=ssh_tunnel,
|
||||
):
|
||||
perm = security_manager.get_schema_perm(
|
||||
database.database_name,
|
||||
catalog,
|
||||
schema,
|
||||
)
|
||||
existing_pvm = session.query(ViewMenu).filter_by(name=perm).one_or_none()
|
||||
if existing_pvm:
|
||||
existing_pvm.name = security_manager.get_schema_perm(
|
||||
database.database_name,
|
||||
None,
|
||||
schema,
|
||||
)
|
||||
downgrade_schema_perms(database, catalog, session)
|
||||
|
||||
# update existing models
|
||||
models = [
|
||||
|
|
@ -231,3 +243,32 @@ def downgrade_catalog_perms(engines: set[str] | None = None) -> None:
|
|||
chart.schema_perm = schema_perm
|
||||
|
||||
session.commit()
|
||||
|
||||
|
||||
def downgrade_schema_perms(database: Database, catalog: str, session: Session) -> None:
|
||||
"""
|
||||
Rename existing schema permissions to omit the catalog.
|
||||
"""
|
||||
ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
|
||||
try:
|
||||
schemas = database.get_all_schema_names(
|
||||
catalog=catalog,
|
||||
cache=False,
|
||||
ssh_tunnel=ssh_tunnel,
|
||||
)
|
||||
except Exception: # pylint: disable=broad-except
|
||||
schemas = get_schemas(database.database_name)
|
||||
|
||||
for schema in schemas:
|
||||
perm = security_manager.get_schema_perm(
|
||||
database.database_name,
|
||||
catalog,
|
||||
schema,
|
||||
)
|
||||
existing_pvm = session.query(ViewMenu).filter_by(name=perm).one_or_none()
|
||||
if existing_pvm:
|
||||
existing_pvm.name = security_manager.get_schema_perm(
|
||||
database.database_name,
|
||||
None,
|
||||
schema,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -143,3 +143,128 @@ def test_upgrade_catalog_perms(mocker: MockerFixture, session: Session) -> None:
|
|||
("[my_db].[public]",),
|
||||
("[my_db].[db]",),
|
||||
]
|
||||
|
||||
|
||||
def test_upgrade_catalog_perms_graceful(
|
||||
mocker: MockerFixture,
|
||||
session: Session,
|
||||
) -> None:
|
||||
"""
|
||||
Test the `upgrade_catalog_perms` function when it fails to connect to the DB.
|
||||
|
||||
During the migration we try to connect to the analytical database to get the list of
|
||||
schemas. This should fail gracefully and not raise an exception, since the database
|
||||
could be offline, and the permissions can be generated later then the admin enables
|
||||
catalog browsing on the database (permissions are always synced on a DB update, see
|
||||
`UpdateDatabaseCommand`).
|
||||
"""
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.models.core import Database
|
||||
from superset.models.slice import Slice
|
||||
from superset.models.sql_lab import Query, SavedQuery, TableSchema, TabState
|
||||
|
||||
engine = session.get_bind()
|
||||
Database.metadata.create_all(engine)
|
||||
|
||||
mocker.patch("superset.migrations.shared.catalogs.op")
|
||||
db = mocker.patch("superset.migrations.shared.catalogs.db")
|
||||
db.Session.return_value = session
|
||||
|
||||
mocker.patch.object(
|
||||
Database,
|
||||
"get_all_schema_names",
|
||||
side_effect=Exception("Failed to connect to the database"),
|
||||
)
|
||||
mocker.patch("superset.migrations.shared.catalogs.op", session)
|
||||
|
||||
database = Database(
|
||||
database_name="my_db",
|
||||
sqlalchemy_uri="postgresql://localhost/db",
|
||||
)
|
||||
dataset = SqlaTable(
|
||||
table_name="my_table",
|
||||
database=database,
|
||||
catalog=None,
|
||||
schema="public",
|
||||
schema_perm="[my_db].[public]",
|
||||
)
|
||||
session.add(dataset)
|
||||
session.commit()
|
||||
|
||||
chart = Slice(
|
||||
slice_name="my_chart",
|
||||
datasource_type="table",
|
||||
datasource_id=dataset.id,
|
||||
)
|
||||
query = Query(
|
||||
client_id="foo",
|
||||
database=database,
|
||||
catalog=None,
|
||||
schema="public",
|
||||
)
|
||||
saved_query = SavedQuery(
|
||||
database=database,
|
||||
sql="SELECT * FROM public.t",
|
||||
catalog=None,
|
||||
schema="public",
|
||||
)
|
||||
tab_state = TabState(
|
||||
database=database,
|
||||
catalog=None,
|
||||
schema="public",
|
||||
)
|
||||
table_schema = TableSchema(
|
||||
database=database,
|
||||
catalog=None,
|
||||
schema="public",
|
||||
)
|
||||
session.add_all([chart, query, saved_query, tab_state, table_schema])
|
||||
session.commit()
|
||||
|
||||
# before migration
|
||||
assert dataset.catalog is None
|
||||
assert query.catalog is None
|
||||
assert saved_query.catalog is None
|
||||
assert tab_state.catalog is None
|
||||
assert table_schema.catalog is None
|
||||
assert dataset.schema_perm == "[my_db].[public]"
|
||||
assert chart.schema_perm == "[my_db].[public]"
|
||||
assert session.query(ViewMenu.name).all() == [
|
||||
("[my_db].(id:1)",),
|
||||
("[my_db].[my_table](id:1)",),
|
||||
("[my_db].[public]",),
|
||||
]
|
||||
|
||||
upgrade_catalog_perms()
|
||||
|
||||
# after migration
|
||||
assert dataset.catalog == "db"
|
||||
assert query.catalog == "db"
|
||||
assert saved_query.catalog == "db"
|
||||
assert tab_state.catalog == "db"
|
||||
assert table_schema.catalog == "db"
|
||||
assert dataset.schema_perm == "[my_db].[db].[public]"
|
||||
assert chart.schema_perm == "[my_db].[db].[public]"
|
||||
assert session.query(ViewMenu.name).all() == [
|
||||
("[my_db].(id:1)",),
|
||||
("[my_db].[my_table](id:1)",),
|
||||
("[my_db].[db].[public]",),
|
||||
("[my_db].[db]",),
|
||||
]
|
||||
|
||||
downgrade_catalog_perms()
|
||||
|
||||
# revert
|
||||
assert dataset.catalog is None
|
||||
assert query.catalog is None
|
||||
assert saved_query.catalog is None
|
||||
assert tab_state.catalog is None
|
||||
assert table_schema.catalog is None
|
||||
assert dataset.schema_perm == "[my_db].[public]"
|
||||
assert chart.schema_perm == "[my_db].[public]"
|
||||
assert session.query(ViewMenu.name).all() == [
|
||||
("[my_db].(id:1)",),
|
||||
("[my_db].[my_table](id:1)",),
|
||||
("[my_db].[public]",),
|
||||
("[my_db].[db]",),
|
||||
]
|
||||
|
|
|
|||
Loading…
Reference in New Issue