fix: upgrade_catalog_perms and downgrade_catalog_perms implementation (#29860)

This commit is contained in:
Michael S. Molina 2024-08-16 08:39:36 -04:00 committed by GitHub
parent 47715c39d0
commit e8f5d7680f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 269 additions and 91 deletions

View File

@ -18,7 +18,8 @@
from __future__ import annotations
import logging
from typing import Any, Type
from datetime import datetime
from typing import Any, Type, Union
import sqlalchemy as sa
from alembic import op
@ -35,8 +36,7 @@ from superset.migrations.shared.security_converge import (
)
from superset.models.core import Database
logger = logging.getLogger(__name__)
logger = logging.getLogger("alembic")
Base: Type[Any] = declarative_base()
@ -95,6 +95,16 @@ class Slice(Base):
schema_perm = sa.Column(sa.String(1000))
ModelType = Union[Type[Query], Type[SavedQuery], Type[TabState], Type[TableSchema]]
MODELS: list[tuple[ModelType, str]] = [
(Query, "database_id"),
(SavedQuery, "db_id"),
(TabState, "database_id"),
(TableSchema, "database_id"),
]
def get_known_schemas(database_name: str, session: Session) -> list[str]:
"""
Read all known schemas from the existing schema permissions.
@ -112,6 +122,234 @@ def get_known_schemas(database_name: str, session: Session) -> list[str]:
return sorted({name[0][1:-1].split("].[")[-1] for name in names})
def get_batch_size(session: Session) -> int:
max_sqlite_in = 999
return max_sqlite_in if session.bind.dialect.name == "sqlite" else 1_000_000
def print_processed_batch(
start_time: datetime,
offset: int,
total_rows: int,
model: ModelType,
batch_size: int,
) -> None:
"""
Print the progress of batch processing.
This function logs the progress of processing a batch of rows from a model.
It calculates the elapsed time since the start of the batch processing and
logs the number of rows processed along with the percentage completion.
Parameters:
start_time (datetime): The start time of the batch processing.
offset (int): The current offset in the batch processing.
total_rows (int): The total number of rows to process.
model (ModelType): The model being processed.
batch_size (int): The size of the batch being processed.
"""
elapsed_time = datetime.now() - start_time
elapsed_seconds = elapsed_time.total_seconds()
elapsed_formatted = f"{int(elapsed_seconds // 3600):02}:{int((elapsed_seconds % 3600) // 60):02}:{int(elapsed_seconds % 60):02}"
rows_processed = min(offset + batch_size, total_rows)
logger.info(
f"{elapsed_formatted} - {rows_processed:,} of {total_rows:,} {model.__tablename__} rows processed "
f"({(rows_processed / total_rows) * 100:.2f}%)"
)
def update_catalog_column(
session: Session, database: Database, catalog: str, downgrade: bool = False
) -> None:
"""
Update the `catalog` column in the specified models to the given catalog.
This function iterates over a list of models defined by MODELS and updates
the `catalog` columnto the specified catalog or None depending on the downgrade
parameter. The update is performed in batches to optimize performance and reduce
memory usage.
Parameters:
session (Session): The SQLAlchemy session to use for database operations.
database (Database): The database instance containing the models to update.
catalog (Catalog): The new catalog value to set in the `catalog` column or
the default catalog if `downgrade` is True.
downgrade (bool): If True, the `catalog` column is set to None where the
catalog matches the specified catalog.
"""
start_time = datetime.now()
logger.info(f"Updating {database.database_name} models to catalog {catalog}")
for model, column in MODELS:
# Get the total number of rows that match the condition
total_rows = (
session.query(sa.func.count(model.id))
.filter(getattr(model, column) == database.id)
.filter(model.catalog == catalog if downgrade else True)
.scalar()
)
logger.info(
f"Total rows to be processed for {model.__tablename__}: {total_rows:,}"
)
batch_size = get_batch_size(session)
limit_value = min(batch_size, total_rows)
# Update in batches using row numbers
for i in range(0, total_rows, batch_size):
subquery = (
session.query(model.id)
.filter(getattr(model, column) == database.id)
.filter(model.catalog == catalog if downgrade else True)
.order_by(model.id)
.offset(i)
.limit(limit_value)
.subquery()
)
# SQLite does not support multiple-table criteria within UPDATE
if session.bind.dialect.name == "sqlite":
ids_to_update = [row.id for row in session.query(subquery.c.id).all()]
if ids_to_update:
session.execute(
sa.update(model)
.where(model.id.in_(ids_to_update))
.values(catalog=None if downgrade else catalog)
.execution_options(synchronize_session=False)
)
else:
session.execute(
sa.update(model)
.where(model.id == subquery.c.id)
.values(catalog=None if downgrade else catalog)
.execution_options(synchronize_session=False)
)
print_processed_batch(start_time, i, total_rows, model, batch_size)
def update_schema_catalog_perms(
session: Session,
database: Database,
catalog_perm: str | None,
catalog: str,
downgrade: bool = False,
) -> None:
"""
Update schema and catalog permissions for tables and charts in a given database.
This function updates the `catalog`, `catalog_perm`, and `schema_perm` fields for
tables and charts associated with the specified database. If `downgrade` is True,
the `catalog` and `catalog_perm` fields are set to None, otherwise they are set
to the provided `catalog` and `catalog_perm` values.
Args:
session (Session): The SQLAlchemy session to use for database operations.
database (Database): The database object whose tables and charts will be updated.
catalog_perm (str): The new catalog permission to set.
catalog (str): The new catalog to set.
downgrade (bool, optional): If True, reset the `catalog` and `catalog_perm` fields to None.
Defaults to False.
"""
# Mapping of table id to schema permission
mapping = {}
for table in (
session.query(SqlaTable)
.filter_by(database_id=database.id)
.filter_by(catalog=catalog if downgrade else None)
):
schema_perm = security_manager.get_schema_perm(
database.database_name,
None if downgrade else catalog,
table.schema,
)
table.catalog = None if downgrade else catalog
table.catalog_perm = catalog_perm
table.schema_perm = schema_perm
mapping[table.id] = schema_perm
# Select all slices of type table that belong to the database
for chart in (
session.query(Slice)
.join(SqlaTable, Slice.datasource_id == SqlaTable.id)
.join(Database, SqlaTable.database_id == Database.id)
.filter(Database.id == database.id)
.filter(Slice.datasource_type == "table")
):
# We only care about tables that exist in the mapping
if mapping.get(chart.datasource_id) is not None:
chart.catalog_perm = catalog_perm
chart.schema_perm = mapping[chart.datasource_id]
def delete_models_non_default_catalog(
session: Session, database: Database, catalog: str
) -> None:
"""
Delete models that are not in the default catalog.
This function iterates over a list of models defined by MODELS and deletes
the rows where the `catalog` column does not match the specified catalog.
Parameters:
session (Session): The SQLAlchemy session to use for database operations.
database (Database): The database instance containing the models to delete.
catalog (Catalog): The catalog to use to filter the models to delete.
"""
start_time = datetime.now()
logger.info(f"Deleting models not in the default catalog: {catalog}")
for model, column in MODELS:
# Get the total number of rows that match the condition
total_rows = (
session.query(sa.func.count(model.id))
.filter(getattr(model, column) == database.id)
.filter(model.catalog != catalog)
.scalar()
)
logger.info(
f"Total rows to be processed for {model.__tablename__}: {total_rows:,}"
)
batch_size = get_batch_size(session)
limit_value = min(batch_size, total_rows)
# Update in batches using row numbers
for i in range(0, total_rows, batch_size):
subquery = (
session.query(model.id)
.filter(getattr(model, column) == database.id)
.filter(model.catalog != catalog)
.order_by(model.id)
.offset(i)
.limit(limit_value)
.subquery()
)
# SQLite does not support multiple-table criteria within DELETE
if session.bind.dialect.name == "sqlite":
ids_to_delete = [row.id for row in session.query(subquery.c.id).all()]
if ids_to_delete:
session.execute(
sa.delete(model)
.where(model.id.in_(ids_to_delete))
.execution_options(synchronize_session=False)
)
else:
session.execute(
sa.delete(model)
.where(model.id == subquery.c.id)
.execution_options(synchronize_session=False)
)
print_processed_batch(start_time, i, total_rows, model, batch_size)
def upgrade_catalog_perms(engines: set[str] | None = None) -> None:
"""
Update models and permissions when catalogs are introduced in a DB engine spec.
@ -157,11 +395,13 @@ def upgrade_database_catalogs(
"""
Upgrade a given database to support the default catalog.
"""
catalog_perm = security_manager.get_catalog_perm(
catalog_perm: str | None = security_manager.get_catalog_perm(
database.database_name,
default_catalog,
)
pvms: dict[str, tuple[str, ...]] = {catalog_perm: ("catalog_access",)}
pvms: dict[str, tuple[str, ...]] = (
{catalog_perm: ("catalog_access",)} if catalog_perm else {}
)
# rename existing schema permissions to include the catalog, and also find any new
# schemas
@ -170,39 +410,10 @@ def upgrade_database_catalogs(
# update existing models that have a `catalog` column so it points to the default
# catalog
models = [
(Query, "database_id"),
(SavedQuery, "db_id"),
(TabState, "database_id"),
(TableSchema, "database_id"),
]
for model, column in models:
for instance in session.query(model).filter(
getattr(model, column) == database.id
):
instance.catalog = default_catalog
update_catalog_column(session, database, default_catalog, False)
# update `schema_perm` and `catalog_perm` for tables and charts
for table in session.query(SqlaTable).filter_by(
database_id=database.id,
catalog=None,
):
schema_perm = security_manager.get_schema_perm(
database.database_name,
default_catalog,
table.schema,
)
table.catalog = default_catalog
table.catalog_perm = catalog_perm
table.schema_perm = schema_perm
for chart in session.query(Slice).filter_by(
datasource_id=table.id,
datasource_type="table",
):
chart.catalog_perm = catalog_perm
chart.schema_perm = schema_perm
update_schema_catalog_perms(session, database, catalog_perm, default_catalog, False)
# add any new catalogs discovered and their schemas
new_catalog_pvms = add_non_default_catalogs(database, default_catalog, session)
@ -233,13 +444,15 @@ def add_non_default_catalogs(
# edited.
return {}
pvms = {}
pvms: dict[str, tuple[str]] = {}
for catalog in catalogs:
perm = security_manager.get_catalog_perm(database.database_name, catalog)
pvms[perm] = ("catalog_access",)
new_schema_pvms = create_schema_perms(database, catalog, session)
pvms.update(new_schema_pvms)
perm: str | None = security_manager.get_catalog_perm(
database.database_name, catalog
)
if perm:
pvms[perm] = ("catalog_access",)
new_schema_pvms = create_schema_perms(database, catalog)
pvms.update(new_schema_pvms)
return pvms
@ -266,12 +479,12 @@ def upgrade_schema_perms(
perms = {}
for schema in schemas:
current_perm = security_manager.get_schema_perm(
current_perm: str | None = security_manager.get_schema_perm(
database.database_name,
None,
schema,
)
new_perm = security_manager.get_schema_perm(
new_perm: str | None = security_manager.get_schema_perm(
database.database_name,
default_catalog,
schema,
@ -283,7 +496,7 @@ def upgrade_schema_perms(
.one_or_none()
):
existing_pvm.name = new_perm
else:
elif new_perm:
# new schema discovered, need to create a new permission
perms[new_perm] = ("schema_access",)
@ -293,7 +506,6 @@ def upgrade_schema_perms(
def create_schema_perms(
database: Database,
catalog: str,
session: Session,
) -> dict[str, tuple[str]]:
"""
Create schema permissions for a given catalog.
@ -307,12 +519,14 @@ def create_schema_perms(
return {}
return {
security_manager.get_schema_perm(
database.database_name,
catalog,
schema,
): ("schema_access",)
perm: ("schema_access",)
for schema in schemas
if (
perm := security_manager.get_schema_perm(
database.database_name, catalog, schema
)
)
is not None
}
@ -374,49 +588,13 @@ def downgrade_database_catalogs(
# permissions associated with other catalogs
downgrade_schema_perms(database, default_catalog, session)
# update existing models
models = [
(Query, "database_id"),
(SavedQuery, "db_id"),
(TabState, "database_id"),
(TableSchema, "database_id"),
]
for model, column in models:
for instance in session.query(model).filter(
getattr(model, column) == database.id,
model.catalog == default_catalog, # type: ignore
):
instance.catalog = None
update_catalog_column(session, database, default_catalog, True)
# update `schema_perm` for tables and charts
for table in session.query(SqlaTable).filter_by(
database_id=database.id,
catalog=default_catalog,
):
schema_perm = security_manager.get_schema_perm(
database.database_name,
None,
table.schema,
)
table.catalog = None
table.catalog_perm = None
table.schema_perm = schema_perm
for chart in session.query(Slice).filter_by(
datasource_id=table.id,
datasource_type="table",
):
chart.catalog_perm = None
chart.schema_perm = schema_perm
# update `schema_perm` and `catalog_perm` for tables and charts
update_schema_catalog_perms(session, database, None, default_catalog, True)
# delete models referencing non-default catalogs
for model, column in models:
for instance in session.query(model).filter(
getattr(model, column) == database.id,
model.catalog != default_catalog, # type: ignore
):
session.delete(instance)
delete_models_non_default_catalog(session, database, default_catalog)
# delete datasets and any associated permissions
for table in session.query(SqlaTable).filter(