fix: catalog upgrade/downgrade (#29780)
This commit is contained in:
parent
8891f04f11
commit
525e837c5b
|
|
@ -92,6 +92,12 @@ ColumnTypeMapping = tuple[
|
|||
|
||||
logger = logging.getLogger()
|
||||
|
||||
# When connecting to a database it's hard to catch specific exceptions, since we support
|
||||
# more than 50 different database drivers. Usually the try/except block will catch the
|
||||
# generic `Exception` class, which requires a pylint disablee comment. To make it clear
|
||||
# that we know this is a necessary evil we create an alias, and catch it instead.
|
||||
GenericDBException = Exception
|
||||
|
||||
|
||||
def convert_inspector_columns(cols: list[SQLAColumnType]) -> list[ResultSetColumnType]:
|
||||
result_set_columns: list[ResultSetColumnType] = []
|
||||
|
|
@ -406,7 +412,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
#
|
||||
# When this is changed to true in a DB engine spec it MUST support the
|
||||
# `get_default_catalog` and `get_catalog_names` methods. In addition, you MUST write
|
||||
# a database migration updating any existing schema permissions.
|
||||
# a database migration updating any existing schema permissions using the helper
|
||||
# `upgrade_catalog_perms`.
|
||||
supports_catalog = False
|
||||
|
||||
# Can the catalog be changed on a per-query basis?
|
||||
|
|
|
|||
|
|
@ -434,8 +434,8 @@ class DatabricksNativeEngineSpec(DatabricksDynamicBaseEngineSpec):
|
|||
cls,
|
||||
database: Database,
|
||||
) -> str | None:
|
||||
with database.get_inspector() as inspector:
|
||||
return inspector.bind.execute("SELECT current_catalog()").scalar()
|
||||
with database.get_sqla_engine() as engine:
|
||||
return engine.execute("SELECT current_catalog()").scalar()
|
||||
|
||||
@classmethod
|
||||
def get_prequeries(
|
||||
|
|
|
|||
|
|
@ -26,8 +26,13 @@ from sqlalchemy.ext.declarative import declarative_base
|
|||
from sqlalchemy.orm import Session
|
||||
|
||||
from superset import db, security_manager
|
||||
from superset.daos.database import DatabaseDAO
|
||||
from superset.migrations.shared.security_converge import add_pvms, ViewMenu
|
||||
from superset.db_engine_specs.base import GenericDBException
|
||||
from superset.migrations.shared.security_converge import (
|
||||
add_pvms,
|
||||
Permission,
|
||||
PermissionView,
|
||||
ViewMenu,
|
||||
)
|
||||
from superset.models.core import Database
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -41,7 +46,9 @@ class SqlaTable(Base):
|
|||
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
database_id = sa.Column(sa.Integer, nullable=False)
|
||||
perm = sa.Column(sa.String(1000))
|
||||
schema_perm = sa.Column(sa.String(1000))
|
||||
catalog_perm = sa.Column(sa.String(1000), nullable=True, default=None)
|
||||
schema = sa.Column(sa.String(255))
|
||||
catalog = sa.Column(sa.String(256), nullable=True, default=None)
|
||||
|
||||
|
|
@ -84,41 +91,47 @@ class Slice(Base):
|
|||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
datasource_id = sa.Column(sa.Integer)
|
||||
datasource_type = sa.Column(sa.String(200))
|
||||
catalog_perm = sa.Column(sa.String(1000), nullable=True, default=None)
|
||||
schema_perm = sa.Column(sa.String(1000))
|
||||
|
||||
|
||||
def get_schemas(database_name: str) -> list[str]:
|
||||
def get_known_schemas(database_name: str, session: Session) -> list[str]:
|
||||
"""
|
||||
Read all known schemas from the schema permissions.
|
||||
Read all known schemas from the existing schema permissions.
|
||||
"""
|
||||
query = f"""
|
||||
SELECT
|
||||
avm.name
|
||||
FROM ab_view_menu avm
|
||||
JOIN ab_permission_view apv ON avm.id = apv.view_menu_id
|
||||
JOIN ab_permission ap ON apv.permission_id = ap.id
|
||||
WHERE
|
||||
avm.name LIKE '[{database_name}]%' AND
|
||||
ap.name = 'schema_access';
|
||||
"""
|
||||
# [PostgreSQL].[postgres].[public] => public
|
||||
conn = op.get_bind()
|
||||
return sorted({row[0].split(".")[-1][1:-1] for row in conn.execute(query)})
|
||||
names = (
|
||||
session.query(ViewMenu.name)
|
||||
.join(PermissionView, ViewMenu.id == PermissionView.view_menu_id)
|
||||
.join(Permission, PermissionView.permission_id == Permission.id)
|
||||
.filter(
|
||||
ViewMenu.name.like(f"[{database_name}]%"),
|
||||
Permission.name == "schema_access",
|
||||
)
|
||||
.all()
|
||||
)
|
||||
return sorted({name[0][1:-1].split("].[")[-1] for name in names})
|
||||
|
||||
|
||||
def upgrade_catalog_perms(engines: set[str] | None = None) -> None:
|
||||
"""
|
||||
Update models when catalogs are introduced in a DB engine spec.
|
||||
Update models and permissions when catalogs are introduced in a DB engine spec.
|
||||
|
||||
When an existing DB engine spec starts to support catalogs we need to:
|
||||
|
||||
- Add a `catalog_access` permission for each catalog.
|
||||
- Populate the `catalog` field with the default catalog for each related model.
|
||||
- Add `catalog_access` permissions for each catalog.
|
||||
- Rename existing `schema_access` permissions to include the default catalog.
|
||||
- Create `schema_access` permissions for each schema in the new catalogs.
|
||||
|
||||
Also, for all the relevant existing models we need to:
|
||||
|
||||
- Populate the `catalog` field with the default catalog.
|
||||
- Update `schema_perm` to include the default catalog.
|
||||
- Populate `catalog_perm` to include the default catalog.
|
||||
|
||||
"""
|
||||
bind = op.get_bind()
|
||||
session = db.Session(bind=bind)
|
||||
|
||||
for database in session.query(Database).all():
|
||||
db_engine_spec = database.db_engine_spec
|
||||
if (
|
||||
|
|
@ -126,83 +139,204 @@ def upgrade_catalog_perms(engines: set[str] | None = None) -> None:
|
|||
) or not db_engine_spec.supports_catalog:
|
||||
continue
|
||||
|
||||
catalog = database.get_default_catalog()
|
||||
if catalog is None:
|
||||
continue
|
||||
# For some databases, fetching the default catalog requires a connection to the
|
||||
# analytical DB. If we can't connect to the analytical DB during the migration
|
||||
# we should stop it, since we need the default catalog in order to update
|
||||
# existing models.
|
||||
if default_catalog := database.get_default_catalog():
|
||||
upgrade_database_catalogs(database, default_catalog, session)
|
||||
|
||||
perm = security_manager.get_catalog_perm(
|
||||
session.flush()
|
||||
|
||||
|
||||
def upgrade_database_catalogs(
|
||||
database: Database,
|
||||
default_catalog: str,
|
||||
session: Session,
|
||||
) -> None:
|
||||
"""
|
||||
Upgrade a given database to support the default catalog.
|
||||
"""
|
||||
catalog_perm = security_manager.get_catalog_perm(
|
||||
database.database_name,
|
||||
default_catalog,
|
||||
)
|
||||
pvms: dict[str, tuple[str, ...]] = {catalog_perm: ("catalog_access",)}
|
||||
|
||||
# rename existing schema permissions to include the catalog, and also find any new
|
||||
# schemas
|
||||
new_schema_pvms = upgrade_schema_perms(database, default_catalog, session)
|
||||
pvms.update(new_schema_pvms)
|
||||
|
||||
# update existing models that have a `catalog` column so it points to the default
|
||||
# catalog
|
||||
models = [
|
||||
(Query, "database_id"),
|
||||
(SavedQuery, "db_id"),
|
||||
(TabState, "database_id"),
|
||||
(TableSchema, "database_id"),
|
||||
]
|
||||
for model, column in models:
|
||||
for instance in session.query(model).filter(
|
||||
getattr(model, column) == database.id
|
||||
):
|
||||
instance.catalog = default_catalog
|
||||
|
||||
# update `schema_perm` and `catalog_perm` for tables and charts
|
||||
for table in session.query(SqlaTable).filter_by(
|
||||
database_id=database.id,
|
||||
catalog=None,
|
||||
):
|
||||
schema_perm = security_manager.get_schema_perm(
|
||||
database.database_name,
|
||||
catalog,
|
||||
default_catalog,
|
||||
table.schema,
|
||||
)
|
||||
add_pvms(session, {perm: ("catalog_access",)})
|
||||
|
||||
upgrade_schema_perms(database, catalog, session)
|
||||
table.catalog = default_catalog
|
||||
table.catalog_perm = catalog_perm
|
||||
table.schema_perm = schema_perm
|
||||
|
||||
# update existing models
|
||||
models = [
|
||||
(Query, "database_id"),
|
||||
(SavedQuery, "db_id"),
|
||||
(TabState, "database_id"),
|
||||
(TableSchema, "database_id"),
|
||||
(SqlaTable, "database_id"),
|
||||
]
|
||||
for model, column in models:
|
||||
for instance in session.query(model).filter(
|
||||
getattr(model, column) == database.id
|
||||
):
|
||||
instance.catalog = catalog
|
||||
for chart in session.query(Slice).filter_by(
|
||||
datasource_id=table.id,
|
||||
datasource_type="table",
|
||||
):
|
||||
chart.catalog_perm = catalog_perm
|
||||
chart.schema_perm = schema_perm
|
||||
|
||||
for table in session.query(SqlaTable).filter_by(database_id=database.id):
|
||||
schema_perm = security_manager.get_schema_perm(
|
||||
database.database_name,
|
||||
catalog,
|
||||
table.schema,
|
||||
)
|
||||
table.schema_perm = schema_perm
|
||||
for chart in session.query(Slice).filter_by(
|
||||
datasource_id=table.id,
|
||||
datasource_type="table",
|
||||
):
|
||||
chart.schema_perm = schema_perm
|
||||
# add any new catalogs discovered and their schemas
|
||||
new_catalog_pvms = add_non_default_catalogs(database, default_catalog, session)
|
||||
pvms.update(new_catalog_pvms)
|
||||
|
||||
session.commit()
|
||||
# add default catalog permission and permissions for any new found schemas, and also
|
||||
# permissions for new catalogs and their schemas
|
||||
add_pvms(session, pvms)
|
||||
|
||||
|
||||
def upgrade_schema_perms(database: Database, catalog: str, session: Session) -> None:
|
||||
def add_non_default_catalogs(
|
||||
database: Database,
|
||||
default_catalog: str,
|
||||
session: Session,
|
||||
) -> dict[str, tuple[str]]:
|
||||
"""
|
||||
Add permissions for additional catalogs and their schemas.
|
||||
"""
|
||||
try:
|
||||
catalogs = {
|
||||
catalog
|
||||
for catalog in database.get_all_catalog_names()
|
||||
if catalog != default_catalog
|
||||
}
|
||||
except GenericDBException:
|
||||
# If we can't connect to the analytical DB to fetch the catalogs we should just
|
||||
# return. The catalog and schema permissions can be created later when the DB is
|
||||
# edited.
|
||||
return {}
|
||||
|
||||
pvms = {}
|
||||
for catalog in catalogs:
|
||||
perm = security_manager.get_catalog_perm(database.database_name, catalog)
|
||||
pvms[perm] = ("catalog_access",)
|
||||
|
||||
new_schema_pvms = create_schema_perms(database, catalog, session)
|
||||
pvms.update(new_schema_pvms)
|
||||
|
||||
return pvms
|
||||
|
||||
|
||||
def upgrade_schema_perms(
|
||||
database: Database,
|
||||
default_catalog: str,
|
||||
session: Session,
|
||||
) -> dict[str, tuple[str]]:
|
||||
"""
|
||||
Rename existing schema permissions to include the catalog.
|
||||
"""
|
||||
ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
|
||||
try:
|
||||
schemas = database.get_all_schema_names(
|
||||
catalog=catalog,
|
||||
cache=False,
|
||||
ssh_tunnel=ssh_tunnel,
|
||||
)
|
||||
except Exception: # pylint: disable=broad-except
|
||||
schemas = get_schemas(database.database_name)
|
||||
|
||||
Schema permissions are stored (and processed) as strings, in the form:
|
||||
|
||||
[database_name].[schema_name]
|
||||
|
||||
When catalogs are first introduced for a DB engine spec we need to rename any
|
||||
existing permissions to the form:
|
||||
|
||||
[database_name].[default_catalog_name].[schema_name]
|
||||
|
||||
"""
|
||||
schemas = get_known_schemas(database.database_name, session)
|
||||
|
||||
perms = {}
|
||||
for schema in schemas:
|
||||
perm = security_manager.get_schema_perm(
|
||||
current_perm = security_manager.get_schema_perm(
|
||||
database.database_name,
|
||||
None,
|
||||
schema,
|
||||
)
|
||||
existing_pvm = session.query(ViewMenu).filter_by(name=perm).one_or_none()
|
||||
if existing_pvm:
|
||||
existing_pvm.name = security_manager.get_schema_perm(
|
||||
database.database_name,
|
||||
catalog,
|
||||
schema,
|
||||
)
|
||||
new_perm = security_manager.get_schema_perm(
|
||||
database.database_name,
|
||||
default_catalog,
|
||||
schema,
|
||||
)
|
||||
|
||||
if (
|
||||
existing_pvm := session.query(ViewMenu)
|
||||
.filter_by(name=current_perm)
|
||||
.one_or_none()
|
||||
):
|
||||
existing_pvm.name = new_perm
|
||||
else:
|
||||
# new schema discovered, need to create a new permission
|
||||
perms[new_perm] = ("schema_access",)
|
||||
|
||||
return perms
|
||||
|
||||
|
||||
def create_schema_perms(
|
||||
database: Database,
|
||||
catalog: str,
|
||||
session: Session,
|
||||
) -> dict[str, tuple[str]]:
|
||||
"""
|
||||
Create schema permissions for a given catalog.
|
||||
"""
|
||||
try:
|
||||
schemas = database.get_all_schema_names(catalog=catalog)
|
||||
except GenericDBException:
|
||||
# If we can't connect to the analytical DB to fetch schemas in this catalog we
|
||||
# should just return. The schema permissions can be created when the DB is
|
||||
# edited.
|
||||
return {}
|
||||
|
||||
return {
|
||||
security_manager.get_schema_perm(
|
||||
database.database_name,
|
||||
catalog,
|
||||
schema,
|
||||
): ("schema_access",)
|
||||
for schema in schemas
|
||||
}
|
||||
|
||||
|
||||
def downgrade_catalog_perms(engines: set[str] | None = None) -> None:
|
||||
"""
|
||||
Reverse the process of `upgrade_catalog_perms`.
|
||||
|
||||
This should:
|
||||
|
||||
- Delete all `catalog_access` permissions.
|
||||
- Rename `schema_access` permissions in the default catalog to omit it.
|
||||
- Delete `schema_access` permissions for schemas not in the default catalog.
|
||||
|
||||
Also, for models in the default catalog we should:
|
||||
|
||||
- Populate the `catalog` field with `None`.
|
||||
- Update `schema_perm` to omit the default catalog.
|
||||
- Populate the `catalog_perm` field with `None`.
|
||||
|
||||
WARNING: models (datasets and charts) not in the default catalog are deleted!
|
||||
"""
|
||||
bind = op.get_bind()
|
||||
session = db.Session(bind=bind)
|
||||
|
||||
for database in session.query(Database).all():
|
||||
db_engine_spec = database.db_engine_spec
|
||||
if (
|
||||
|
|
@ -210,70 +344,155 @@ def downgrade_catalog_perms(engines: set[str] | None = None) -> None:
|
|||
) or not db_engine_spec.supports_catalog:
|
||||
continue
|
||||
|
||||
catalog = database.get_default_catalog()
|
||||
if catalog is None:
|
||||
if default_catalog := database.get_default_catalog():
|
||||
downgrade_database_catalogs(database, default_catalog, session)
|
||||
|
||||
session.flush()
|
||||
|
||||
|
||||
def downgrade_database_catalogs(
|
||||
database: Database,
|
||||
default_catalog: str,
|
||||
session: Session,
|
||||
) -> None:
|
||||
# remove all catalog permissions associated with the DB
|
||||
prefix = f"[{database.database_name}].%"
|
||||
for pvm in (
|
||||
session.query(PermissionView)
|
||||
.join(Permission, PermissionView.permission_id == Permission.id)
|
||||
.join(ViewMenu, PermissionView.view_menu_id == ViewMenu.id)
|
||||
.filter(
|
||||
Permission.name == "catalog_access",
|
||||
ViewMenu.name.like(prefix),
|
||||
)
|
||||
.all()
|
||||
):
|
||||
session.delete(pvm)
|
||||
session.delete(pvm.view_menu)
|
||||
|
||||
# rename existing schemas permissions to omit the catalog, and remove schema
|
||||
# permissions associated with other catalogs
|
||||
downgrade_schema_perms(database, default_catalog, session)
|
||||
|
||||
# update existing models
|
||||
models = [
|
||||
(Query, "database_id"),
|
||||
(SavedQuery, "db_id"),
|
||||
(TabState, "database_id"),
|
||||
(TableSchema, "database_id"),
|
||||
]
|
||||
for model, column in models:
|
||||
for instance in session.query(model).filter(
|
||||
getattr(model, column) == database.id,
|
||||
model.catalog == default_catalog, # type: ignore
|
||||
):
|
||||
instance.catalog = None
|
||||
|
||||
# update `schema_perm` for tables and charts
|
||||
for table in session.query(SqlaTable).filter_by(
|
||||
database_id=database.id,
|
||||
catalog=default_catalog,
|
||||
):
|
||||
schema_perm = security_manager.get_schema_perm(
|
||||
database.database_name,
|
||||
None,
|
||||
table.schema,
|
||||
)
|
||||
|
||||
table.catalog = None
|
||||
table.catalog_perm = None
|
||||
table.schema_perm = schema_perm
|
||||
|
||||
for chart in session.query(Slice).filter_by(
|
||||
datasource_id=table.id,
|
||||
datasource_type="table",
|
||||
):
|
||||
chart.catalog_perm = None
|
||||
chart.schema_perm = schema_perm
|
||||
|
||||
# delete models referencing non-default catalogs
|
||||
for model, column in models:
|
||||
for instance in session.query(model).filter(
|
||||
getattr(model, column) == database.id,
|
||||
model.catalog != default_catalog, # type: ignore
|
||||
):
|
||||
session.delete(instance)
|
||||
|
||||
# delete datasets and any associated permissions
|
||||
for table in session.query(SqlaTable).filter(
|
||||
SqlaTable.database_id == database.id,
|
||||
SqlaTable.catalog != default_catalog,
|
||||
):
|
||||
for chart in session.query(Slice).filter(
|
||||
Slice.datasource_id == table.id,
|
||||
Slice.datasource_type == "table",
|
||||
):
|
||||
session.delete(chart)
|
||||
|
||||
session.delete(table)
|
||||
pvm = (
|
||||
session.query(PermissionView)
|
||||
.join(Permission, PermissionView.permission_id == Permission.id)
|
||||
.join(ViewMenu, PermissionView.view_menu_id == ViewMenu.id)
|
||||
.filter(
|
||||
Permission.name == "datasource_access",
|
||||
ViewMenu.name == table.perm,
|
||||
)
|
||||
.one()
|
||||
)
|
||||
session.delete(pvm)
|
||||
session.delete(pvm.view_menu)
|
||||
|
||||
session.flush()
|
||||
|
||||
|
||||
def downgrade_schema_perms(
|
||||
database: Database,
|
||||
default_catalog: str,
|
||||
session: Session,
|
||||
) -> None:
|
||||
"""
|
||||
Rename default catalog schema permissions and delete other schema permissions.
|
||||
"""
|
||||
prefix = f"[{database.database_name}].%"
|
||||
pvms = (
|
||||
session.query(PermissionView)
|
||||
.join(Permission, PermissionView.permission_id == Permission.id)
|
||||
.join(ViewMenu, PermissionView.view_menu_id == ViewMenu.id)
|
||||
.filter(
|
||||
Permission.name == "schema_access",
|
||||
ViewMenu.name.like(prefix),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
pvms_to_delete = []
|
||||
pvms_to_rename = []
|
||||
for pvm in pvms:
|
||||
parts = pvm.view_menu.name[1:-1].split("].[")
|
||||
if len(parts) != 3:
|
||||
logger.warning(
|
||||
"Invalid schema permission: %s. Please fix manually",
|
||||
pvm.view_menu.name,
|
||||
)
|
||||
continue
|
||||
|
||||
downgrade_schema_perms(database, catalog, session)
|
||||
database_name, catalog, schema = parts
|
||||
|
||||
# update existing models
|
||||
models = [
|
||||
(Query, "database_id"),
|
||||
(SavedQuery, "db_id"),
|
||||
(TabState, "database_id"),
|
||||
(TableSchema, "database_id"),
|
||||
(SqlaTable, "database_id"),
|
||||
]
|
||||
for model, column in models:
|
||||
for instance in session.query(model).filter(
|
||||
getattr(model, column) == database.id
|
||||
):
|
||||
instance.catalog = None
|
||||
|
||||
for table in session.query(SqlaTable).filter_by(database_id=database.id):
|
||||
schema_perm = security_manager.get_schema_perm(
|
||||
database.database_name,
|
||||
None,
|
||||
table.schema,
|
||||
)
|
||||
table.schema_perm = schema_perm
|
||||
for chart in session.query(Slice).filter_by(
|
||||
datasource_id=table.id,
|
||||
datasource_type="table",
|
||||
):
|
||||
chart.schema_perm = schema_perm
|
||||
|
||||
session.commit()
|
||||
|
||||
|
||||
def downgrade_schema_perms(database: Database, catalog: str, session: Session) -> None:
|
||||
"""
|
||||
Rename existing schema permissions to omit the catalog.
|
||||
"""
|
||||
ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
|
||||
try:
|
||||
schemas = database.get_all_schema_names(
|
||||
catalog=catalog,
|
||||
cache=False,
|
||||
ssh_tunnel=ssh_tunnel,
|
||||
)
|
||||
except Exception: # pylint: disable=broad-except
|
||||
schemas = get_schemas(database.database_name)
|
||||
|
||||
for schema in schemas:
|
||||
perm = security_manager.get_schema_perm(
|
||||
database.database_name,
|
||||
catalog,
|
||||
schema,
|
||||
)
|
||||
existing_pvm = session.query(ViewMenu).filter_by(name=perm).one_or_none()
|
||||
if existing_pvm:
|
||||
new_perm = security_manager.get_schema_perm(
|
||||
database.database_name,
|
||||
if catalog == default_catalog:
|
||||
new_name = security_manager.get_schema_perm(
|
||||
database_name,
|
||||
None,
|
||||
schema,
|
||||
)
|
||||
if pvm := session.query(ViewMenu).filter_by(name=new_perm).one_or_none():
|
||||
session.delete(pvm)
|
||||
session.flush()
|
||||
existing_pvm.name = new_perm
|
||||
pvms_to_rename.append((pvm, new_name))
|
||||
else:
|
||||
# non-default catalog, delete schema perm
|
||||
pvms_to_delete.append(pvm)
|
||||
|
||||
for pvm in pvms_to_delete:
|
||||
session.delete(pvm)
|
||||
session.delete(pvm.view_menu)
|
||||
|
||||
for pvm, new_name in pvms_to_rename:
|
||||
pvm.view_menu.name = new_name
|
||||
|
|
|
|||
|
|
@ -22,17 +22,18 @@ from superset.migrations.shared.catalogs import (
|
|||
downgrade_catalog_perms,
|
||||
upgrade_catalog_perms,
|
||||
)
|
||||
from superset.migrations.shared.security_converge import ViewMenu
|
||||
from superset.migrations.shared.security_converge import (
|
||||
Permission,
|
||||
PermissionView,
|
||||
ViewMenu,
|
||||
)
|
||||
|
||||
|
||||
def test_upgrade_catalog_perms(mocker: MockerFixture, session: Session) -> None:
|
||||
"""
|
||||
Test the `upgrade_catalog_perms` function.
|
||||
|
||||
The function is called when catalogs are introduced into a new DB engine spec. When
|
||||
that happens, we need to update the `catalog` attribute so it points to the default
|
||||
catalog, instead of being `NULL`. We also need to update `schema_perms` to include
|
||||
the default catalog.
|
||||
The function is called when catalogs are introduced into a new DB engine spec.
|
||||
"""
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.models.core import Database
|
||||
|
|
@ -51,6 +52,11 @@ def test_upgrade_catalog_perms(mocker: MockerFixture, session: Session) -> None:
|
|||
"get_all_schema_names",
|
||||
return_value=["public", "information_schema"],
|
||||
)
|
||||
mocker.patch.object(
|
||||
Database,
|
||||
"get_all_catalog_names",
|
||||
return_value=["db", "other_catalog"],
|
||||
)
|
||||
|
||||
database = Database(
|
||||
database_name="my_db",
|
||||
|
|
@ -61,6 +67,7 @@ def test_upgrade_catalog_perms(mocker: MockerFixture, session: Session) -> None:
|
|||
database=database,
|
||||
catalog=None,
|
||||
schema="public",
|
||||
catalog_perm=None,
|
||||
schema_perm="[my_db].[public]",
|
||||
)
|
||||
session.add(dataset)
|
||||
|
|
@ -70,6 +77,8 @@ def test_upgrade_catalog_perms(mocker: MockerFixture, session: Session) -> None:
|
|||
slice_name="my_chart",
|
||||
datasource_type="table",
|
||||
datasource_id=dataset.id,
|
||||
catalog_perm=None,
|
||||
schema_perm="[my_db].[public]",
|
||||
)
|
||||
query = Query(
|
||||
client_id="foo",
|
||||
|
|
@ -102,15 +111,43 @@ def test_upgrade_catalog_perms(mocker: MockerFixture, session: Session) -> None:
|
|||
assert saved_query.catalog is None
|
||||
assert tab_state.catalog is None
|
||||
assert table_schema.catalog is None
|
||||
assert dataset.catalog_perm is None
|
||||
assert dataset.schema_perm == "[my_db].[public]"
|
||||
assert chart.catalog_perm is None
|
||||
assert chart.schema_perm == "[my_db].[public]"
|
||||
assert session.query(ViewMenu.name).all() == [
|
||||
("[my_db].(id:1)",),
|
||||
("[my_db].[my_table](id:1)",),
|
||||
("[my_db].[public]",),
|
||||
assert (
|
||||
session.query(ViewMenu.name, Permission.name)
|
||||
.join(PermissionView, ViewMenu.id == PermissionView.view_menu_id)
|
||||
.join(Permission, PermissionView.permission_id == Permission.id)
|
||||
.all()
|
||||
) == [
|
||||
("[my_db].(id:1)", "database_access"),
|
||||
("[my_db].[my_table](id:1)", "datasource_access"),
|
||||
("[my_db].[public]", "schema_access"),
|
||||
]
|
||||
|
||||
upgrade_catalog_perms()
|
||||
session.commit()
|
||||
|
||||
# add dataset/chart in new catalog
|
||||
new_dataset = SqlaTable(
|
||||
table_name="my_table",
|
||||
database=database,
|
||||
catalog="other_catalog",
|
||||
schema="public",
|
||||
schema_perm="[my_db].[other_catalog].[public]",
|
||||
catalog_perm="[my_db].[other_catalog]",
|
||||
)
|
||||
session.add(new_dataset)
|
||||
session.commit()
|
||||
|
||||
new_chart = Slice(
|
||||
slice_name="my_chart",
|
||||
datasource_type="table",
|
||||
datasource_id=new_dataset.id,
|
||||
)
|
||||
session.add(new_chart)
|
||||
session.commit()
|
||||
|
||||
# after migration
|
||||
assert dataset.catalog == "db"
|
||||
|
|
@ -118,16 +155,29 @@ def test_upgrade_catalog_perms(mocker: MockerFixture, session: Session) -> None:
|
|||
assert saved_query.catalog == "db"
|
||||
assert tab_state.catalog == "db"
|
||||
assert table_schema.catalog == "db"
|
||||
assert dataset.catalog_perm == "[my_db].[db]"
|
||||
assert dataset.schema_perm == "[my_db].[db].[public]"
|
||||
assert chart.catalog_perm == "[my_db].[db]"
|
||||
assert chart.schema_perm == "[my_db].[db].[public]"
|
||||
assert session.query(ViewMenu.name).all() == [
|
||||
("[my_db].(id:1)",),
|
||||
("[my_db].[my_table](id:1)",),
|
||||
("[my_db].[db].[public]",),
|
||||
("[my_db].[db]",),
|
||||
assert (
|
||||
session.query(ViewMenu.name, Permission.name)
|
||||
.join(PermissionView, ViewMenu.id == PermissionView.view_menu_id)
|
||||
.join(Permission, PermissionView.permission_id == Permission.id)
|
||||
.all()
|
||||
) == [
|
||||
("[my_db].(id:1)", "database_access"),
|
||||
("[my_db].[my_table](id:1)", "datasource_access"),
|
||||
("[my_db].[db].[public]", "schema_access"),
|
||||
("[my_db].[db]", "catalog_access"),
|
||||
("[my_db].[other_catalog]", "catalog_access"),
|
||||
("[my_db].[other_catalog].[public]", "schema_access"),
|
||||
("[my_db].[other_catalog].[information_schema]", "schema_access"),
|
||||
("[my_db].[my_table](id:2)", "datasource_access"),
|
||||
]
|
||||
|
||||
# do a downgrade
|
||||
downgrade_catalog_perms()
|
||||
session.commit()
|
||||
|
||||
# revert
|
||||
assert dataset.catalog is None
|
||||
|
|
@ -135,15 +185,25 @@ def test_upgrade_catalog_perms(mocker: MockerFixture, session: Session) -> None:
|
|||
assert saved_query.catalog is None
|
||||
assert tab_state.catalog is None
|
||||
assert table_schema.catalog is None
|
||||
assert dataset.catalog_perm is None
|
||||
assert dataset.schema_perm == "[my_db].[public]"
|
||||
assert chart.catalog_perm is None
|
||||
assert chart.schema_perm == "[my_db].[public]"
|
||||
assert session.query(ViewMenu.name).all() == [
|
||||
("[my_db].(id:1)",),
|
||||
("[my_db].[my_table](id:1)",),
|
||||
("[my_db].[public]",),
|
||||
("[my_db].[db]",),
|
||||
assert (
|
||||
session.query(ViewMenu.name, Permission.name)
|
||||
.join(PermissionView, ViewMenu.id == PermissionView.view_menu_id)
|
||||
.join(Permission, PermissionView.permission_id == Permission.id)
|
||||
.all()
|
||||
) == [
|
||||
("[my_db].(id:1)", "database_access"),
|
||||
("[my_db].[my_table](id:1)", "datasource_access"),
|
||||
("[my_db].[public]", "schema_access"),
|
||||
]
|
||||
|
||||
# make sure new dataset/chart were deleted
|
||||
assert session.query(SqlaTable).all() == [dataset]
|
||||
assert session.query(Slice).all() == [chart]
|
||||
|
||||
|
||||
def test_upgrade_catalog_perms_graceful(
|
||||
mocker: MockerFixture,
|
||||
|
|
@ -236,6 +296,7 @@ def test_upgrade_catalog_perms_graceful(
|
|||
]
|
||||
|
||||
upgrade_catalog_perms()
|
||||
session.commit()
|
||||
|
||||
# after migration
|
||||
assert dataset.catalog == "db"
|
||||
|
|
@ -253,6 +314,7 @@ def test_upgrade_catalog_perms_graceful(
|
|||
]
|
||||
|
||||
downgrade_catalog_perms()
|
||||
session.commit()
|
||||
|
||||
# revert
|
||||
assert dataset.catalog is None
|
||||
|
|
@ -266,5 +328,4 @@ def test_upgrade_catalog_perms_graceful(
|
|||
("[my_db].(id:1)",),
|
||||
("[my_db].[my_table](id:1)",),
|
||||
("[my_db].[public]",),
|
||||
("[my_db].[db]",),
|
||||
]
|
||||
|
|
|
|||
Loading…
Reference in New Issue