fix: upgrade_catalog_perms and downgrade_catalog_perms implementation (#29860)
This commit is contained in:
parent
47715c39d0
commit
e8f5d7680f
|
|
@ -18,7 +18,8 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Type
|
||||
from datetime import datetime
|
||||
from typing import Any, Type, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
|
@ -35,8 +36,7 @@ from superset.migrations.shared.security_converge import (
|
|||
)
|
||||
from superset.models.core import Database
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
logger = logging.getLogger("alembic")
|
||||
|
||||
Base: Type[Any] = declarative_base()
|
||||
|
||||
|
|
@ -95,6 +95,16 @@ class Slice(Base):
|
|||
schema_perm = sa.Column(sa.String(1000))
|
||||
|
||||
|
||||
# Any of the model classes whose rows are rewritten or deleted in batches by the
# catalog helpers in this module (each is accessed through `id` and `catalog`).
ModelType = Union[
    Type[Query],
    Type[SavedQuery],
    Type[TabState],
    Type[TableSchema],
]

# Pairs of (model, name of the column on that model that is compared against
# `database.id` when selecting the rows belonging to a given database).
MODELS: list[tuple[ModelType, str]] = [
    (Query, "database_id"),
    (SavedQuery, "db_id"),
    (TabState, "database_id"),
    (TableSchema, "database_id"),
]
|
||||
|
||||
|
||||
def get_known_schemas(database_name: str, session: Session) -> list[str]:
|
||||
"""
|
||||
Read all known schemas from the existing schema permissions.
|
||||
|
|
@ -112,6 +122,234 @@ def get_known_schemas(database_name: str, session: Session) -> list[str]:
|
|||
return sorted({name[0][1:-1].split("].[")[-1] for name in names})
|
||||
|
||||
|
||||
def get_batch_size(session: Session) -> int:
    """
    Return how many rows should be processed per batch for the session's dialect.

    SQLite gets a batch of 999 rows, which bounds the ID lists that the SQLite
    code paths in this module pass to ``IN (...)``; every other dialect gets
    1,000,000 rows per batch.

    Parameters:
        session (Session): The SQLAlchemy session whose bound dialect is inspected.

    Returns:
        int: The batch size to use.
    """
    if session.bind.dialect.name == "sqlite":
        return 999
    return 1_000_000
|
||||
|
||||
|
||||
def print_processed_batch(
    start_time: datetime,
    offset: int,
    total_rows: int,
    model: ModelType,
    batch_size: int,
) -> None:
    """
    Log the progress of batch processing.

    The elapsed time since `start_time` is formatted as HH:MM:SS and logged
    together with the number of rows processed so far and the percentage of
    `total_rows` that this represents.

    Parameters:
        start_time (datetime): The start time of the batch processing.
        offset (int): The offset of the batch that was just processed.
        total_rows (int): The total number of rows to process.
        model (ModelType): The model being processed; its `__tablename__` is
            included in the message.
        batch_size (int): The size of the batch being processed.
    """
    elapsed_seconds = int((datetime.now() - start_time).total_seconds())
    hours, remainder = divmod(elapsed_seconds, 3600)
    minutes, seconds = divmod(remainder, 60)
    elapsed_formatted = f"{hours:02}:{minutes:02}:{seconds:02}"

    # The last batch may be partial, so never report more than the total.
    rows_processed = min(offset + batch_size, total_rows)

    # With nothing to process the work is trivially complete; this also avoids
    # dividing by zero.
    percent = (rows_processed / total_rows) * 100 if total_rows else 100.0

    logger.info(
        f"{elapsed_formatted} - {rows_processed:,} of {total_rows:,} "
        f"{model.__tablename__} rows processed "
        f"({percent:.2f}%)"
    )
|
||||
|
||||
|
||||
def update_catalog_column(
    session: Session, database: Database, catalog: str, downgrade: bool = False
) -> None:
    """
    Update the `catalog` column in the specified models to the given catalog.

    This function iterates over the models listed in MODELS and, for the rows
    that belong to `database`, sets the `catalog` column to `catalog` (upgrade)
    or resets it to None where it currently equals `catalog` (downgrade). The
    update is performed in batches.

    Parameters:
        session (Session): The SQLAlchemy session to use for database operations.
        database (Database): The database whose rows are updated.
        catalog (str): The catalog value to write on upgrade, or the catalog
            value to match (and clear) on downgrade.
        downgrade (bool): If True, the `catalog` column is set to None where it
            matches `catalog`; otherwise every row of the database is set to
            `catalog`.
    """
    start_time = datetime.now()

    logger.info(f"Updating {database.database_name} models to catalog {catalog}")

    # Loop invariants.
    batch_size = get_batch_size(session)
    is_sqlite = session.bind.dialect.name == "sqlite"
    new_value = None if downgrade else catalog

    # On downgrade the rows just updated stop matching `catalog == <catalog>`,
    # so the filtered set shrinks after every batch. Paging with a growing
    # OFFSET would then skip rows; the next batch must be read from the start.
    # (If `catalog` were None the rows would still match, so keep the offset.)
    rows_leave_filter = downgrade and catalog is not None

    for model, column in MODELS:
        criteria = [getattr(model, column) == database.id]
        if downgrade:
            criteria.append(model.catalog == catalog)

        # Get the total number of rows that match the condition
        total_rows = (
            session.query(sa.func.count(model.id)).filter(*criteria).scalar()
        )

        logger.info(
            f"Total rows to be processed for {model.__tablename__}: {total_rows:,}"
        )

        limit_value = min(batch_size, total_rows)

        # Update in batches, ordered by primary key
        for i in range(0, total_rows, batch_size):
            subquery = (
                session.query(model.id)
                .filter(*criteria)
                .order_by(model.id)
                .offset(0 if rows_leave_filter else i)
                .limit(limit_value)
                .subquery()
            )

            # SQLite does not support multiple-table criteria within UPDATE
            if is_sqlite:
                ids_to_update = [row.id for row in session.query(subquery.c.id).all()]
                if ids_to_update:
                    session.execute(
                        sa.update(model)
                        .where(model.id.in_(ids_to_update))
                        .values(catalog=new_value)
                        .execution_options(synchronize_session=False)
                    )
            else:
                session.execute(
                    sa.update(model)
                    .where(model.id == subquery.c.id)
                    .values(catalog=new_value)
                    .execution_options(synchronize_session=False)
                )

            print_processed_batch(start_time, i, total_rows, model, batch_size)
|
||||
|
||||
|
||||
def update_schema_catalog_perms(
    session: Session,
    database: Database,
    catalog_perm: str | None,
    catalog: str,
    downgrade: bool = False,
) -> None:
    """
    Update schema and catalog permissions for tables and charts in a given database.

    This function updates the `catalog`, `catalog_perm`, and `schema_perm` fields
    for the tables of the specified database, and the `catalog_perm` and
    `schema_perm` fields of the charts built on those tables. On upgrade the
    tables with no catalog are moved to `catalog`; on downgrade the tables in
    `catalog` have their catalog reset to None.

    Args:
        session (Session): The SQLAlchemy session to use for database operations.
        database (Database): The database whose tables and charts will be updated.
        catalog_perm (str | None): The catalog permission to set on tables and
            charts (callers pass None on downgrade).
        catalog (str): The catalog to set on upgrade, or to match on downgrade.
        downgrade (bool, optional): If True, reset the `catalog` field to None
            and compute schema permissions without a catalog. Defaults to False.
    """
    target_catalog = None if downgrade else catalog

    # Mapping of table id to schema permission. The permission itself may be
    # None, so membership (not the value) marks a table as processed.
    mapping: dict[int, str | None] = {}

    for table in (
        session.query(SqlaTable)
        .filter_by(database_id=database.id)
        .filter_by(catalog=catalog if downgrade else None)
    ):
        schema_perm = security_manager.get_schema_perm(
            database.database_name,
            target_catalog,
            table.schema,
        )
        table.catalog = target_catalog
        table.catalog_perm = catalog_perm
        table.schema_perm = schema_perm
        mapping[table.id] = schema_perm

    # Select all slices of type table that belong to the database
    for chart in (
        session.query(Slice)
        .join(SqlaTable, Slice.datasource_id == SqlaTable.id)
        .join(Database, SqlaTable.database_id == Database.id)
        .filter(Database.id == database.id)
        .filter(Slice.datasource_type == "table")
    ):
        # We only care about charts whose table was updated above. Test for
        # membership so that a table whose schema permission is None still
        # propagates its (None) permission instead of leaving stale values.
        if chart.datasource_id in mapping:
            chart.catalog_perm = catalog_perm
            chart.schema_perm = mapping[chart.datasource_id]
|
||||
|
||||
|
||||
def delete_models_non_default_catalog(
    session: Session, database: Database, catalog: str
) -> None:
    """
    Delete models that are not in the default catalog.

    This function iterates over the models listed in MODELS and deletes, in
    batches, the rows of `database` whose `catalog` column differs from the
    specified catalog.

    Parameters:
        session (Session): The SQLAlchemy session to use for database operations.
        database (Database): The database whose rows are deleted.
        catalog (str): The catalog to keep; rows whose catalog differs from it
            are deleted.
    """
    start_time = datetime.now()

    logger.info(f"Deleting models not in the default catalog: {catalog}")

    # Loop invariants.
    batch_size = get_batch_size(session)
    is_sqlite = session.bind.dialect.name == "sqlite"

    for model, column in MODELS:
        criteria = (
            getattr(model, column) == database.id,
            model.catalog != catalog,
        )

        # Get the total number of rows that match the condition
        total_rows = (
            session.query(sa.func.count(model.id)).filter(*criteria).scalar()
        )

        logger.info(
            f"Total rows to be processed for {model.__tablename__}: {total_rows:,}"
        )

        limit_value = min(batch_size, total_rows)

        # Delete in batches, ordered by primary key. `i` only drives the number
        # of iterations and the progress report: deleted rows disappear from
        # the filtered set, so each batch is read from offset 0. Advancing the
        # offset would skip rows that are still waiting to be deleted.
        for i in range(0, total_rows, batch_size):
            subquery = (
                session.query(model.id)
                .filter(*criteria)
                .order_by(model.id)
                .offset(0)
                .limit(limit_value)
                .subquery()
            )

            # SQLite does not support multiple-table criteria within DELETE
            if is_sqlite:
                ids_to_delete = [row.id for row in session.query(subquery.c.id).all()]
                if ids_to_delete:
                    session.execute(
                        sa.delete(model)
                        .where(model.id.in_(ids_to_delete))
                        .execution_options(synchronize_session=False)
                    )
            else:
                session.execute(
                    sa.delete(model)
                    .where(model.id == subquery.c.id)
                    .execution_options(synchronize_session=False)
                )

            print_processed_batch(start_time, i, total_rows, model, batch_size)
|
||||
|
||||
|
||||
def upgrade_catalog_perms(engines: set[str] | None = None) -> None:
|
||||
"""
|
||||
Update models and permissions when catalogs are introduced in a DB engine spec.
|
||||
|
|
@ -157,11 +395,13 @@ def upgrade_database_catalogs(
|
|||
"""
|
||||
Upgrade a given database to support the default catalog.
|
||||
"""
|
||||
catalog_perm = security_manager.get_catalog_perm(
|
||||
catalog_perm: str | None = security_manager.get_catalog_perm(
|
||||
database.database_name,
|
||||
default_catalog,
|
||||
)
|
||||
pvms: dict[str, tuple[str, ...]] = {catalog_perm: ("catalog_access",)}
|
||||
pvms: dict[str, tuple[str, ...]] = (
|
||||
{catalog_perm: ("catalog_access",)} if catalog_perm else {}
|
||||
)
|
||||
|
||||
# rename existing schema permissions to include the catalog, and also find any new
|
||||
# schemas
|
||||
|
|
@ -170,39 +410,10 @@ def upgrade_database_catalogs(
|
|||
|
||||
# update existing models that have a `catalog` column so it points to the default
|
||||
# catalog
|
||||
models = [
|
||||
(Query, "database_id"),
|
||||
(SavedQuery, "db_id"),
|
||||
(TabState, "database_id"),
|
||||
(TableSchema, "database_id"),
|
||||
]
|
||||
for model, column in models:
|
||||
for instance in session.query(model).filter(
|
||||
getattr(model, column) == database.id
|
||||
):
|
||||
instance.catalog = default_catalog
|
||||
update_catalog_column(session, database, default_catalog, False)
|
||||
|
||||
# update `schema_perm` and `catalog_perm` for tables and charts
|
||||
for table in session.query(SqlaTable).filter_by(
|
||||
database_id=database.id,
|
||||
catalog=None,
|
||||
):
|
||||
schema_perm = security_manager.get_schema_perm(
|
||||
database.database_name,
|
||||
default_catalog,
|
||||
table.schema,
|
||||
)
|
||||
|
||||
table.catalog = default_catalog
|
||||
table.catalog_perm = catalog_perm
|
||||
table.schema_perm = schema_perm
|
||||
|
||||
for chart in session.query(Slice).filter_by(
|
||||
datasource_id=table.id,
|
||||
datasource_type="table",
|
||||
):
|
||||
chart.catalog_perm = catalog_perm
|
||||
chart.schema_perm = schema_perm
|
||||
update_schema_catalog_perms(session, database, catalog_perm, default_catalog, False)
|
||||
|
||||
# add any new catalogs discovered and their schemas
|
||||
new_catalog_pvms = add_non_default_catalogs(database, default_catalog, session)
|
||||
|
|
@ -233,13 +444,15 @@ def add_non_default_catalogs(
|
|||
# edited.
|
||||
return {}
|
||||
|
||||
pvms = {}
|
||||
pvms: dict[str, tuple[str]] = {}
|
||||
for catalog in catalogs:
|
||||
perm = security_manager.get_catalog_perm(database.database_name, catalog)
|
||||
pvms[perm] = ("catalog_access",)
|
||||
|
||||
new_schema_pvms = create_schema_perms(database, catalog, session)
|
||||
pvms.update(new_schema_pvms)
|
||||
perm: str | None = security_manager.get_catalog_perm(
|
||||
database.database_name, catalog
|
||||
)
|
||||
if perm:
|
||||
pvms[perm] = ("catalog_access",)
|
||||
new_schema_pvms = create_schema_perms(database, catalog)
|
||||
pvms.update(new_schema_pvms)
|
||||
|
||||
return pvms
|
||||
|
||||
|
|
@ -266,12 +479,12 @@ def upgrade_schema_perms(
|
|||
|
||||
perms = {}
|
||||
for schema in schemas:
|
||||
current_perm = security_manager.get_schema_perm(
|
||||
current_perm: str | None = security_manager.get_schema_perm(
|
||||
database.database_name,
|
||||
None,
|
||||
schema,
|
||||
)
|
||||
new_perm = security_manager.get_schema_perm(
|
||||
new_perm: str | None = security_manager.get_schema_perm(
|
||||
database.database_name,
|
||||
default_catalog,
|
||||
schema,
|
||||
|
|
@ -283,7 +496,7 @@ def upgrade_schema_perms(
|
|||
.one_or_none()
|
||||
):
|
||||
existing_pvm.name = new_perm
|
||||
else:
|
||||
elif new_perm:
|
||||
# new schema discovered, need to create a new permission
|
||||
perms[new_perm] = ("schema_access",)
|
||||
|
||||
|
|
@ -293,7 +506,6 @@ def upgrade_schema_perms(
|
|||
def create_schema_perms(
|
||||
database: Database,
|
||||
catalog: str,
|
||||
session: Session,
|
||||
) -> dict[str, tuple[str]]:
|
||||
"""
|
||||
Create schema permissions for a given catalog.
|
||||
|
|
@ -307,12 +519,14 @@ def create_schema_perms(
|
|||
return {}
|
||||
|
||||
return {
|
||||
security_manager.get_schema_perm(
|
||||
database.database_name,
|
||||
catalog,
|
||||
schema,
|
||||
): ("schema_access",)
|
||||
perm: ("schema_access",)
|
||||
for schema in schemas
|
||||
if (
|
||||
perm := security_manager.get_schema_perm(
|
||||
database.database_name, catalog, schema
|
||||
)
|
||||
)
|
||||
is not None
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -374,49 +588,13 @@ def downgrade_database_catalogs(
|
|||
# permissions associated with other catalogs
|
||||
downgrade_schema_perms(database, default_catalog, session)
|
||||
|
||||
# update existing models
|
||||
models = [
|
||||
(Query, "database_id"),
|
||||
(SavedQuery, "db_id"),
|
||||
(TabState, "database_id"),
|
||||
(TableSchema, "database_id"),
|
||||
]
|
||||
for model, column in models:
|
||||
for instance in session.query(model).filter(
|
||||
getattr(model, column) == database.id,
|
||||
model.catalog == default_catalog, # type: ignore
|
||||
):
|
||||
instance.catalog = None
|
||||
update_catalog_column(session, database, default_catalog, True)
|
||||
|
||||
# update `schema_perm` for tables and charts
|
||||
for table in session.query(SqlaTable).filter_by(
|
||||
database_id=database.id,
|
||||
catalog=default_catalog,
|
||||
):
|
||||
schema_perm = security_manager.get_schema_perm(
|
||||
database.database_name,
|
||||
None,
|
||||
table.schema,
|
||||
)
|
||||
|
||||
table.catalog = None
|
||||
table.catalog_perm = None
|
||||
table.schema_perm = schema_perm
|
||||
|
||||
for chart in session.query(Slice).filter_by(
|
||||
datasource_id=table.id,
|
||||
datasource_type="table",
|
||||
):
|
||||
chart.catalog_perm = None
|
||||
chart.schema_perm = schema_perm
|
||||
# update `schema_perm` and `catalog_perm` for tables and charts
|
||||
update_schema_catalog_perms(session, database, None, default_catalog, True)
|
||||
|
||||
# delete models referencing non-default catalogs
|
||||
for model, column in models:
|
||||
for instance in session.query(model).filter(
|
||||
getattr(model, column) == database.id,
|
||||
model.catalog != default_catalog, # type: ignore
|
||||
):
|
||||
session.delete(instance)
|
||||
delete_models_non_default_catalog(session, database, default_catalog)
|
||||
|
||||
# delete datasets and any associated permissions
|
||||
for table in session.query(SqlaTable).filter(
|
||||
|
|
|
|||
Loading…
Reference in New Issue