diff --git a/superset/commands/database/create.py b/superset/commands/database/create.py
index 4903938eb..b45107ca8 100644
--- a/superset/commands/database/create.py
+++ b/superset/commands/database/create.py
@@ -97,12 +97,37 @@ class CreateDatabaseCommand(BaseCommand):
db.session.commit()
- # adding a new database we always want to force refresh schema list
- schemas = database.get_all_schema_names(cache=False, ssh_tunnel=ssh_tunnel)
- for schema in schemas:
- security_manager.add_permission_view_menu(
- "schema_access", security_manager.get_schema_perm(database, schema)
+ # add catalog/schema permissions
+ if database.db_engine_spec.supports_catalog:
+ catalogs = database.get_all_catalog_names(
+ cache=False,
+ ssh_tunnel=ssh_tunnel,
)
+ for catalog in catalogs:
+ security_manager.add_permission_view_menu(
+ "catalog_access",
+ security_manager.get_catalog_perm(
+ database.database_name, catalog
+ ),
+ )
+ else:
+ # add a dummy catalog for DBs that don't support them
+ catalogs = [None]
+
+ for catalog in catalogs:
+ for schema in database.get_all_schema_names(
+ catalog=catalog,
+ cache=False,
+ ssh_tunnel=ssh_tunnel,
+ ):
+ security_manager.add_permission_view_menu(
+ "schema_access",
+ security_manager.get_schema_perm(
+ database.database_name,
+ catalog,
+ schema,
+ ),
+ )
except (
SSHTunnelInvalidError,
diff --git a/superset/commands/database/tables.py b/superset/commands/database/tables.py
index 055c0be9a..b16fcfc50 100644
--- a/superset/commands/database/tables.py
+++ b/superset/commands/database/tables.py
@@ -14,6 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+
+from __future__ import annotations
+
import logging
from typing import Any, cast
@@ -29,7 +32,6 @@ from superset.daos.database import DatabaseDAO
from superset.exceptions import SupersetException
from superset.extensions import db, security_manager
from superset.models.core import Database
-from superset.utils.core import DatasourceName
logger = logging.getLogger(__name__)
@@ -37,8 +39,15 @@ logger = logging.getLogger(__name__)
class TablesDatabaseCommand(BaseCommand):
_model: Database
- def __init__(self, db_id: int, schema_name: str, force: bool):
+ def __init__(
+ self,
+ db_id: int,
+ catalog_name: str | None,
+ schema_name: str,
+ force: bool,
+ ):
self._db_id = db_id
+ self._catalog_name = catalog_name
self._schema_name = schema_name
self._force = force
@@ -47,11 +56,11 @@ class TablesDatabaseCommand(BaseCommand):
try:
tables = security_manager.get_datasources_accessible_by_user(
database=self._model,
+ catalog=self._catalog_name,
schema=self._schema_name,
datasource_names=sorted(
- DatasourceName(*datasource_name)
- for datasource_name in self._model.get_all_table_names_in_schema(
- catalog=None,
+ self._model.get_all_table_names_in_schema(
+ catalog=self._catalog_name,
schema=self._schema_name,
force=self._force,
cache=self._model.table_cache_enabled,
@@ -62,11 +71,11 @@ class TablesDatabaseCommand(BaseCommand):
views = security_manager.get_datasources_accessible_by_user(
database=self._model,
+ catalog=self._catalog_name,
schema=self._schema_name,
datasource_names=sorted(
- DatasourceName(*datasource_name)
- for datasource_name in self._model.get_all_view_names_in_schema(
- catalog=None,
+ self._model.get_all_view_names_in_schema(
+ catalog=self._catalog_name,
schema=self._schema_name,
force=self._force,
cache=self._model.table_cache_enabled,
@@ -81,11 +90,15 @@ class TablesDatabaseCommand(BaseCommand):
db.session.query(SqlaTable)
.filter(
SqlaTable.database_id == self._model.id,
+ SqlaTable.catalog == self._catalog_name,
SqlaTable.schema == self._schema_name,
)
.options(
load_only(
- SqlaTable.schema, SqlaTable.table_name, SqlaTable.extra
+ SqlaTable.catalog,
+ SqlaTable.schema,
+ SqlaTable.table_name,
+ SqlaTable.extra,
),
lazyload(SqlaTable.columns),
lazyload(SqlaTable.metrics),
diff --git a/superset/commands/database/update.py b/superset/commands/database/update.py
index 5e0968954..c59984238 100644
--- a/superset/commands/database/update.py
+++ b/superset/commands/database/update.py
@@ -18,10 +18,9 @@
from __future__ import annotations
import logging
-from typing import Any, Optional
+from typing import Any
from flask_appbuilder.models.sqla import Model
-from marshmallow import ValidationError
from superset import is_feature_enabled, security_manager
from superset.commands.base import BaseCommand
@@ -50,12 +49,12 @@ logger = logging.getLogger(__name__)
class UpdateDatabaseCommand(BaseCommand):
- _model: Optional[Database]
+ _model: Database | None
def __init__(self, model_id: int, data: dict[str, Any]):
self._properties = data.copy()
self._model_id = model_id
- self._model: Optional[Database] = None
+ self._model: Database | None = None
def run(self) -> Model:
self._model = DatabaseDAO.find_by_id(self._model_id)
@@ -85,7 +84,7 @@ class UpdateDatabaseCommand(BaseCommand):
)
database.set_sqlalchemy_uri(database.sqlalchemy_uri)
ssh_tunnel = self._handle_ssh_tunnel(database)
- self._refresh_schemas(database, original_database_name, ssh_tunnel)
+ self._refresh_catalogs(database, original_database_name, ssh_tunnel)
except SSHTunnelError as ex:
# allow exception to bubble for debugbing information
raise ex
@@ -121,67 +120,200 @@ class UpdateDatabaseCommand(BaseCommand):
ssh_tunnel_properties,
).run()
- def _refresh_schemas(
+ def _get_catalog_names(
self,
database: Database,
- original_database_name: str,
- ssh_tunnel: Optional[SSHTunnel],
- ) -> None:
+ ssh_tunnel: SSHTunnel | None,
+ ) -> set[str]:
"""
- Add permissions for any new schemas.
+ Helper method to load catalogs.
+
+ This method captures a generic exception, since errors could potentially come
+ from any of the 50+ database drivers we support.
"""
try:
- schemas = database.get_all_schema_names(ssh_tunnel=ssh_tunnel)
+ return database.get_all_catalog_names(
+ force=True,
+ ssh_tunnel=ssh_tunnel,
+ )
except Exception as ex:
db.session.rollback()
raise DatabaseConnectionFailedError() from ex
- for schema in schemas:
- original_vm = security_manager.get_schema_perm(
+ def _get_schema_names(
+ self,
+ database: Database,
+ catalog: str | None,
+ ssh_tunnel: SSHTunnel | None,
+ ) -> set[str]:
+ """
+ Helper method to load schemas.
+
+ This method captures a generic exception, since errors could potentially come
+ from any of the 50+ database drivers we support.
+ """
+ try:
+ return database.get_all_schema_names(
+ force=True,
+ catalog=catalog,
+ ssh_tunnel=ssh_tunnel,
+ )
+ except Exception as ex:
+ db.session.rollback()
+ raise DatabaseConnectionFailedError() from ex
+
+ def _refresh_catalogs(
+ self,
+ database: Database,
+ original_database_name: str,
+ ssh_tunnel: SSHTunnel | None,
+ ) -> None:
+ """
+ Add permissions for any new catalogs and schemas.
+ """
+ catalogs = (
+ self._get_catalog_names(database, ssh_tunnel)
+ if database.db_engine_spec.supports_catalog
+ else [None]
+ )
+
+ for catalog in catalogs:
+ schemas = self._get_schema_names(database, catalog, ssh_tunnel)
+
+ if catalog:
+ perm = security_manager.get_catalog_perm(
+ original_database_name,
+ catalog,
+ )
+ existing_pvm = security_manager.find_permission_view_menu(
+ "catalog_access",
+ perm,
+ )
+ if not existing_pvm:
+ # new catalog
+ security_manager.add_permission_view_menu(
+ "catalog_access",
+ security_manager.get_catalog_perm(
+ database.database_name,
+ catalog,
+ ),
+ )
+ for schema in schemas:
+ security_manager.add_permission_view_menu(
+ "schema_access",
+ security_manager.get_schema_perm(
+ database.database_name,
+ catalog,
+ schema,
+ ),
+ )
+ continue
+
+ # add possible new schemas in catalog
+ self._refresh_schemas(
+ database,
original_database_name,
+ catalog,
+ schemas,
+ )
+
+ if original_database_name != database.database_name:
+ self._rename_database_in_permissions(
+ database,
+ original_database_name,
+ catalog,
+ schemas,
+ )
+
+ db.session.commit()
+
+ def _refresh_schemas(
+ self,
+ database: Database,
+ original_database_name: str,
+ catalog: str | None,
+ schemas: set[str],
+ ) -> None:
+ """
+ Add new schemas that don't have permissions yet.
+ """
+ for schema in schemas:
+ perm = security_manager.get_schema_perm(
+ original_database_name,
+ catalog,
schema,
)
existing_pvm = security_manager.find_permission_view_menu(
"schema_access",
- original_vm,
+ perm,
)
if not existing_pvm:
- # new schema
- security_manager.add_permission_view_menu(
- "schema_access",
- security_manager.get_schema_perm(database.database_name, schema),
+ new_name = security_manager.get_schema_perm(
+ database.database_name,
+ catalog,
+ schema,
)
- continue
+ security_manager.add_permission_view_menu("schema_access", new_name)
- if original_database_name == database.database_name:
- continue
+ def _rename_database_in_permissions(
+ self,
+ database: Database,
+ original_database_name: str,
+ catalog: str | None,
+ schemas: set[str],
+ ) -> None:
+ new_name = security_manager.get_catalog_perm(
+ database.database_name,
+ catalog,
+ )
- # rename existing schema permission
- existing_pvm.view_menu.name = security_manager.get_schema_perm(
+ # rename existing catalog permission
+ if catalog:
+ perm = security_manager.get_catalog_perm(
+ original_database_name,
+ catalog,
+ )
+ existing_pvm = security_manager.find_permission_view_menu(
+ "catalog_access",
+ perm,
+ )
+ if existing_pvm:
+ existing_pvm.view_menu.name = new_name
+
+ for schema in schemas:
+ new_name = security_manager.get_schema_perm(
database.database_name,
+ catalog,
schema,
)
+ # rename existing schema permission
+ perm = security_manager.get_schema_perm(
+ original_database_name,
+ catalog,
+ schema,
+ )
+ existing_pvm = security_manager.find_permission_view_menu(
+ "schema_access",
+ perm,
+ )
+ if existing_pvm:
+ existing_pvm.view_menu.name = new_name
+
# rename permissions on datasets and charts
for dataset in DatabaseDAO.get_datasets(
database.id,
- catalog=None,
+ catalog=catalog,
schema=schema,
):
- dataset.schema_perm = existing_pvm.view_menu.name
+ dataset.schema_perm = new_name
for chart in DatasetDAO.get_related_objects(dataset.id)["charts"]:
- chart.schema_perm = existing_pvm.view_menu.name
-
- db.session.commit()
+ chart.schema_perm = new_name
def validate(self) -> None:
- exceptions: list[ValidationError] = []
- database_name: Optional[str] = self._properties.get("database_name")
- if database_name:
- # Check database_name uniqueness
+ if database_name := self._properties.get("database_name"):
if not DatabaseDAO.validate_update_uniqueness(
- self._model_id, database_name
+ self._model_id,
+ database_name,
):
- exceptions.append(DatabaseExistsValidationError())
- if exceptions:
- raise DatabaseInvalidError(exceptions=exceptions)
+ raise DatabaseInvalidError(exceptions=[DatabaseExistsValidationError()])
diff --git a/superset/commands/sql_lab/export.py b/superset/commands/sql_lab/export.py
index aa6050f27..bfa739054 100644
--- a/superset/commands/sql_lab/export.py
+++ b/superset/commands/sql_lab/export.py
@@ -126,7 +126,11 @@ class SqlResultExportCommand(BaseCommand):
}:
# remove extra row from `increased_limit`
limit -= 1
- df = self._query.database.get_df(sql, self._query.schema)[:limit]
+ df = self._query.database.get_df(
+ sql,
+ self._query.catalog,
+ self._query.schema,
+ )[:limit]
csv_data = csv.df_to_escaped_csv(df, index=False, **config["CSV_EXPORT"])
diff --git a/superset/config.py b/superset/config.py
index 9388edbe8..7851938b7 100644
--- a/superset/config.py
+++ b/superset/config.py
@@ -564,9 +564,9 @@ IS_FEATURE_ENABLED_FUNC: Callable[[str, bool | None], bool] | None = None
#
# Takes as a parameter the common bootstrap payload before transformations.
# Returns a dict containing data that should be added or overridden to the payload.
-COMMON_BOOTSTRAP_OVERRIDES_FUNC: Callable[[dict[str, Any]], dict[str, Any]] = ( # noqa: E731
- lambda data: {}
-) # default: empty dict
+COMMON_BOOTSTRAP_OVERRIDES_FUNC: Callable[ # noqa: E731
+ [dict[str, Any]], dict[str, Any]
+] = lambda data: {}
# EXTRA_CATEGORICAL_COLOR_SCHEMES is used for adding custom categorical color schemes
# example code for "My custom warm to hot" color scheme
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 719d5af58..12fbdc3bd 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -211,6 +211,7 @@ class BaseDatasource(AuditMixinNullable, ImportExportMixin): # pylint: disable=
params = Column(String(1000))
perm = Column(String(1000))
schema_perm = Column(String(1000))
+ catalog_perm = Column(String(1000), nullable=True, default=None)
is_managed_externally = Column(Boolean, nullable=False, default=False)
external_url = Column(Text, nullable=True)
@@ -1261,9 +1262,20 @@ class SqlaTable(
anchor = f'<a target="_blank" href="{self.explore_url}">{name}</a>'
return Markup(anchor)
+ def get_catalog_perm(self) -> str | None:
+ """Return the catalog permission for the table's database and catalog."""
+ return security_manager.get_catalog_perm(
+ self.database.database_name,
+ self.catalog,
+ )
+
def get_schema_perm(self) -> str | None:
"""Returns schema permission if present, database one otherwise."""
- return security_manager.get_schema_perm(self.database, self.schema or None)
+ return security_manager.get_schema_perm(
+ self.database.database_name,
+ self.catalog,
+ self.schema or None,
+ )
def get_perm(self) -> str:
"""
@@ -1282,7 +1294,10 @@ class SqlaTable(
@property
def full_name(self) -> str:
return utils.get_datasource_full_name(
- self.database, self.table_name, schema=self.schema
+ self.database,
+ self.table_name,
+ catalog=self.catalog,
+ schema=self.schema,
)
@property
@@ -1736,7 +1751,10 @@ class SqlaTable(
try:
df = self.database.get_df(
- sql, self.schema or None, mutator=assign_column_label
+ sql,
+ None,
+ self.schema or None,
+ mutator=assign_column_label,
)
except (SupersetErrorException, SupersetErrorsException) as ex:
# SupersetError(s) exception should not be captured; instead, they should
@@ -1870,34 +1888,45 @@ class SqlaTable(
cls,
database: Database,
datasource_name: str,
+ catalog: str | None = None,
schema: str | None = None,
) -> list[SqlaTable]:
- query = (
- db.session.query(cls)
- .filter_by(database_id=database.id)
- .filter_by(table_name=datasource_name)
- )
+ filters = {
+ "database_id": database.id,
+ "table_name": datasource_name,
+ }
+ if catalog:
+ filters["catalog"] = catalog
if schema:
- query = query.filter_by(schema=schema)
- return query.all()
+ filters["schema"] = schema
+
+ return db.session.query(cls).filter_by(**filters).all()
@classmethod
def query_datasources_by_permissions( # pylint: disable=invalid-name
cls,
database: Database,
permissions: set[str],
+ catalog_perms: set[str],
schema_perms: set[str],
) -> list[SqlaTable]:
- # TODO(hughhhh): add unit test
+ # remove empty sets from the query, since SQLAlchemy produces horrible SQL for
+ # Model.column.in_({}):
+ #
+ # table.column IN (SELECT 1 FROM (SELECT 1) WHERE 1!=1)
+ filters = [
+ method.in_(perms)
+ for method, perms in zip(
+ (SqlaTable.perm, SqlaTable.schema_perm, SqlaTable.catalog_perm),
+ (permissions, schema_perms, catalog_perms),
+ )
+ if perms
+ ]
+
return (
db.session.query(cls)
.filter_by(database_id=database.id)
- .filter(
- or_(
- SqlaTable.perm.in_(permissions),
- SqlaTable.schema_perm.in_(schema_perms),
- )
- )
+ .filter(or_(*filters))
.all()
)
diff --git a/superset/constants.py b/superset/constants.py
index bac83ae55..42cde6115 100644
--- a/superset/constants.py
+++ b/superset/constants.py
@@ -132,6 +132,7 @@ MODEL_API_RW_METHOD_PERMISSION_MAP = {
"related_objects": "read",
"tables": "read",
"schemas": "read",
+ "catalogs": "read",
"select_star": "read",
"table_metadata": "read",
"table_metadata_deprecated": "read",
diff --git a/superset/databases/api.py b/superset/databases/api.py
index 46a0b8cc6..116ac9f46 100644
--- a/superset/databases/api.py
+++ b/superset/databases/api.py
@@ -73,10 +73,12 @@ from superset.daos.database import DatabaseDAO, DatabaseUserOAuth2TokensDAO
from superset.databases.decorators import check_table_access
from superset.databases.filters import DatabaseFilter, DatabaseUploadEnabledFilter
from superset.databases.schemas import (
+ CatalogsResponseSchema,
ColumnarMetadataUploadFilePostSchema,
ColumnarUploadPostSchema,
CSVMetadataUploadFilePostSchema,
CSVUploadPostSchema,
+ database_catalogs_query_schema,
database_schemas_query_schema,
database_tables_query_schema,
DatabaseConnectionSchema,
@@ -146,6 +148,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
"table_extra_metadata",
"table_extra_metadata_deprecated",
"select_star",
+ "catalogs",
"schemas",
"test_connection",
"related_objects",
@@ -266,6 +269,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
edit_model_schema = DatabasePutSchema()
apispec_parameter_schemas = {
+ "database_catalogs_query_schema": database_catalogs_query_schema,
"database_schemas_query_schema": database_schemas_query_schema,
"database_tables_query_schema": database_tables_query_schema,
"get_export_ids_schema": get_export_ids_schema,
@@ -273,6 +277,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
openapi_spec_tag = "Database"
openapi_spec_component_schemas = (
+ CatalogsResponseSchema,
ColumnarUploadPostSchema,
CSVUploadPostSchema,
DatabaseConnectionSchema,
@@ -604,6 +609,70 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
)
return self.response_422(message=str(ex))
+ @expose("/<int:pk>/catalogs/")
+ @protect()
+ @safe
+ @rison(database_catalogs_query_schema)
+ @statsd_metrics
+ @event_logger.log_this_with_context(
+ action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.catalogs",
+ log_to_statsd=False,
+ )
+ def catalogs(self, pk: int, **kwargs: Any) -> FlaskResponse:
+ """Get all catalogs from a database.
+ ---
+ get:
+ summary: Get all catalogs from a database
+ parameters:
+ - in: path
+ schema:
+ type: integer
+ name: pk
+ description: The database id
+ - in: query
+ name: q
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/database_catalogs_query_schema'
+ responses:
+ 200:
+ description: A List of all catalogs from the database
+ content:
+ application/json:
+ schema:
+ $ref: "#/components/schemas/CatalogsResponseSchema"
+ 400:
+ $ref: '#/components/responses/400'
+ 401:
+ $ref: '#/components/responses/401'
+ 404:
+ $ref: '#/components/responses/404'
+ 500:
+ $ref: '#/components/responses/500'
+ """
+ database = DatabaseDAO.find_by_id(pk)
+ if not database:
+ return self.response_404()
+ try:
+ catalogs = database.get_all_catalog_names(
+ cache=database.catalog_cache_enabled,
+ cache_timeout=database.catalog_cache_timeout or None,
+ force=kwargs["rison"].get("force", False),
+ )
+ catalogs = security_manager.get_catalogs_accessible_by_user(
+ database,
+ catalogs,
+ )
+ return self.response(200, result=list(catalogs))
+ except OperationalError:
+ return self.response(
+ 500,
+ message="There was an error connecting to the database",
+ )
+ except SupersetException as ex:
+ return self.response(ex.status, message=ex.message)
+
@expose("/<int:pk>/schemas/")
@protect()
@safe
@@ -650,13 +719,19 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
if not database:
return self.response_404()
try:
+ catalog = kwargs["rison"].get("catalog")
schemas = database.get_all_schema_names(
+ catalog=catalog,
cache=database.schema_cache_enabled,
cache_timeout=database.schema_cache_timeout or None,
force=kwargs["rison"].get("force", False),
)
- schemas = security_manager.get_schemas_accessible_by_user(database, schemas)
- return self.response(200, result=schemas)
+ schemas = security_manager.get_schemas_accessible_by_user(
+ database,
+ catalog,
+ schemas,
+ )
+ return self.response(200, result=list(schemas))
except OperationalError:
return self.response(
500, message="There was an error connecting to the database"
@@ -718,10 +793,11 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
$ref: '#/components/responses/500'
"""
force = kwargs["rison"].get("force", False)
+ catalog_name = kwargs["rison"].get("catalog_name")
schema_name = kwargs["rison"].get("schema_name", "")
try:
- command = TablesDatabaseCommand(pk, schema_name, force)
+ command = TablesDatabaseCommand(pk, catalog_name, schema_name, force)
payload = command.run()
return self.response(200, **payload)
except DatabaseNotFoundError:
diff --git a/superset/databases/filters.py b/superset/databases/filters.py
index 33748da4b..420a55fd2 100644
--- a/superset/databases/filters.py
+++ b/superset/databases/filters.py
@@ -28,11 +28,12 @@ from superset.models.core import Database
from superset.views.base import BaseFilter
-def can_access_databases(
- view_menu_name: str,
-) -> set[str]:
+def can_access_databases(view_menu_name: str) -> set[str]:
+ """
+ Return names of databases available in `view_menu_name`.
+ """
return {
- security_manager.unpack_database_and_schema(vm).database
+ vm.split(".")[0][1:-1]
for vm in security_manager.user_view_menu_names(view_menu_name)
}
@@ -56,17 +57,21 @@ class DatabaseFilter(BaseFilter): # pylint: disable=too-few-public-methods
# We can proceed with default filtering now
if security_manager.can_access_all_databases():
return query
- database_perms = security_manager.user_view_menu_names("database_access")
- schema_access_databases = can_access_databases("schema_access")
+ database_perms = security_manager.user_view_menu_names("database_access")
+ catalog_access_databases = can_access_databases("catalog_access")
+ schema_access_databases = can_access_databases("schema_access")
datasource_access_databases = can_access_databases("datasource_access")
+ database_names = sorted(
+ catalog_access_databases
+ | schema_access_databases
+ | datasource_access_databases
+ )
return query.filter(
or_(
self.model.perm.in_(database_perms),
- self.model.database_name.in_(
- [*schema_access_databases, *datasource_access_databases]
- ),
+ self.model.database_name.in_(database_names),
)
)
diff --git a/superset/databases/schemas.py b/superset/databases/schemas.py
index 074139857..4318a1b48 100644
--- a/superset/databases/schemas.py
+++ b/superset/databases/schemas.py
@@ -56,6 +56,14 @@ from superset.security.analytics_db_safety import check_sqlalchemy_uri
from superset.utils.core import markdown, parse_ssl_cert
database_schemas_query_schema = {
+ "type": "object",
+ "properties": {
+ "force": {"type": "boolean"},
+ "catalog": {"type": "string"},
+ },
+}
+
+database_catalogs_query_schema = {
"type": "object",
"properties": {"force": {"type": "boolean"}},
}
@@ -65,6 +73,7 @@ database_tables_query_schema = {
"properties": {
"force": {"type": "boolean"},
"schema_name": {"type": "string"},
+ "catalog_name": {"type": "string"},
},
"required": ["schema_name"],
}
@@ -712,6 +721,12 @@ class SchemasResponseSchema(Schema):
)
+class CatalogsResponseSchema(Schema):
+ result = fields.List(
+ fields.String(metadata={"description": "A database catalog name"})
+ )
+
+
class DatabaseTablesResponse(Schema):
extra = fields.Dict(
metadata={"description": "Extra data used to specify column metadata"}
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index 3cc131512..4ee4e4dc0 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -404,6 +404,10 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
# Does the DB support catalogs? A catalog here is a group of schemas, and has
# different names depending on the DB: BigQuery calles it a "project", Postgres calls
# it a "database", Trino calls it a "catalog", etc.
+ #
+ # When this is changed to true in a DB engine spec it MUST support the
+ # `get_default_catalog` and `get_catalog_names` methods. In addition, you MUST write
+ # a database migration updating any existing schema permissions.
supports_catalog = False
# Can the catalog be changed on a per-query basis?
@@ -638,10 +642,20 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return driver in cls.drivers
+ @classmethod
+ def get_default_catalog(
+ cls,
+ database: Database, # pylint: disable=unused-argument
+ ) -> str | None:
+ """
+ Return the default catalog for a given database.
+ """
+ return None
+
@classmethod
def get_default_schema(cls, database: Database, catalog: str | None) -> str | None:
"""
- Return the default schema in a given database.
+ Return the default schema for a catalog in a given database.
"""
with database.get_inspector(catalog=catalog) as inspector:
return inspector.default_schema_name
@@ -1412,24 +1426,24 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
cls,
database: Database,
inspector: Inspector,
- ) -> list[str]:
+ ) -> set[str]:
"""
Get all catalogs from database.
This needs to be implemented per database, since SQLAlchemy doesn't offer an
abstraction.
"""
- return []
+ return set()
@classmethod
- def get_schema_names(cls, inspector: Inspector) -> list[str]:
+ def get_schema_names(cls, inspector: Inspector) -> set[str]:
"""
Get all schemas from database
:param inspector: SqlAlchemy inspector
:return: All schemas in the database
"""
- return sorted(inspector.get_schema_names())
+ return set(inspector.get_schema_names())
@classmethod
def get_table_names( # pylint: disable=unused-argument
diff --git a/superset/db_engine_specs/bigquery.py b/superset/db_engine_specs/bigquery.py
index 8a2612f5b..ca52bd51c 100644
--- a/superset/db_engine_specs/bigquery.py
+++ b/superset/db_engine_specs/bigquery.py
@@ -127,7 +127,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
allows_hidden_cc_in_orderby = True
- supports_catalog = True
+ supports_catalog = False
"""
https://www.python.org/dev/peps/pep-0249/#arraysize
@@ -464,7 +464,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
cls,
database: Database,
inspector: Inspector,
- ) -> list[str]:
+ ) -> set[str]:
"""
Get all catalogs.
@@ -475,7 +475,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
client = cls._get_client(engine)
projects = client.list_projects()
- return sorted(project.project_id for project in projects)
+ return {project.project_id for project in projects}
@classmethod
def get_allow_cost_estimate(cls, extra: dict[str, Any]) -> bool:
diff --git a/superset/db_engine_specs/clickhouse.py b/superset/db_engine_specs/clickhouse.py
index 4346f77d6..aa5399929 100644
--- a/superset/db_engine_specs/clickhouse.py
+++ b/superset/db_engine_specs/clickhouse.py
@@ -278,7 +278,7 @@ class ClickHouseConnectEngineSpec(BasicParametersMixin, ClickHouseEngineSpec):
@classmethod
def get_function_names(cls, database: Database) -> list[str]:
- # pylint: disable=import-outside-toplevel,import-error
+ # pylint: disable=import-outside-toplevel, import-error
from clickhouse_connect.driver.exceptions import ClickHouseError
if cls._function_names:
@@ -340,7 +340,7 @@ class ClickHouseConnectEngineSpec(BasicParametersMixin, ClickHouseEngineSpec):
def validate_parameters(
cls, properties: BasicPropertiesType
) -> list[SupersetError]:
- # pylint: disable=import-outside-toplevel,import-error
+ # pylint: disable=import-outside-toplevel, import-error
from clickhouse_connect.driver import default_port
parameters = properties.get("parameters", {})
diff --git a/superset/db_engine_specs/impala.py b/superset/db_engine_specs/impala.py
index 1d3ec4e9e..d7d1862aa 100644
--- a/superset/db_engine_specs/impala.py
+++ b/superset/db_engine_specs/impala.py
@@ -74,13 +74,12 @@ class ImpalaEngineSpec(BaseEngineSpec):
return None
@classmethod
- def get_schema_names(cls, inspector: Inspector) -> list[str]:
- schemas = [
+ def get_schema_names(cls, inspector: Inspector) -> set[str]:
+ return {
row[0]
for row in inspector.engine.execute("SHOW SCHEMAS")
if not row[0].startswith("_")
- ]
- return schemas
+ }
@classmethod
def has_implicit_cancel(cls) -> bool:
diff --git a/superset/db_engine_specs/postgres.py b/superset/db_engine_specs/postgres.py
index ce87aa1f9..bba2157e0 100644
--- a/superset/db_engine_specs/postgres.py
+++ b/superset/db_engine_specs/postgres.py
@@ -101,8 +101,6 @@ class PostgresBaseEngineSpec(BaseEngineSpec):
engine = ""
engine_name = "PostgreSQL"
- supports_catalog = True
-
_time_grain_expressions = {
None: "{col}",
TimeGrain.SECOND: "DATE_TRUNC('second', {col})",
@@ -199,7 +197,10 @@ class PostgresBaseEngineSpec(BaseEngineSpec):
class PostgresEngineSpec(BasicParametersMixin, PostgresBaseEngineSpec):
engine = "postgresql"
engine_aliases = {"postgres"}
+
supports_dynamic_schema = True
+ supports_catalog = True
+ supports_dynamic_catalog = True
default_driver = "psycopg2"
sqlalchemy_uri_placeholder = (
@@ -296,6 +297,29 @@ class PostgresEngineSpec(BasicParametersMixin, PostgresBaseEngineSpec):
return super().get_default_schema_for_query(database, query)
+ @classmethod
+ def adjust_engine_params(
+ cls,
+ uri: URL,
+ connect_args: dict[str, Any],
+ catalog: str | None = None,
+ schema: str | None = None,
+ ) -> tuple[URL, dict[str, Any]]:
+ """
+ Set the catalog (database).
+ """
+ if catalog:
+ uri = uri.set(database=catalog)
+
+ return uri, connect_args
+
+ @classmethod
+ def get_default_catalog(cls, database: Database) -> str | None:
+ """
+ Return the default catalog for a given database.
+ """
+ return database.url_object.database
+
@classmethod
def get_prequeries(
cls,
@@ -346,13 +370,13 @@ class PostgresEngineSpec(BasicParametersMixin, PostgresBaseEngineSpec):
cls,
database: Database,
inspector: Inspector,
- ) -> list[str]:
+ ) -> set[str]:
"""
Return all catalogs.
In Postgres, a catalog is called a "database".
"""
- return sorted(
+ return {
catalog
for (catalog,) in inspector.bind.execute(
"""
@@ -360,7 +384,7 @@ SELECT datname FROM pg_database
WHERE datistemplate = false;
"""
)
- )
+ }
@classmethod
def get_table_names(
diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py
index 34c47eb52..c59b1b1d1 100644
--- a/superset/db_engine_specs/presto.py
+++ b/superset/db_engine_specs/presto.py
@@ -648,6 +648,8 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
engine_name = "Presto"
allows_alias_to_source_column = False
+ supports_catalog = False
+
custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = {
COLUMN_DOES_NOT_EXIST_REGEX: (
__(
@@ -815,11 +817,11 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
cls,
database: Database,
inspector: Inspector,
- ) -> list[str]:
+ ) -> set[str]:
"""
Get all catalogs.
"""
- return [catalog for (catalog,) in inspector.bind.execute("SHOW CATALOGS")]
+ return {catalog for (catalog,) in inspector.bind.execute("SHOW CATALOGS")}
@classmethod
def _create_column_info(
diff --git a/superset/db_engine_specs/snowflake.py b/superset/db_engine_specs/snowflake.py
index 83d382cda..9a82cfcca 100644
--- a/superset/db_engine_specs/snowflake.py
+++ b/superset/db_engine_specs/snowflake.py
@@ -85,7 +85,7 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
sqlalchemy_uri_placeholder = "snowflake://"
supports_dynamic_schema = True
- supports_catalog = True
+ supports_catalog = False
_time_grain_expressions = {
None: "{col}",
@@ -174,18 +174,18 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
cls,
database: "Database",
inspector: Inspector,
- ) -> list[str]:
+ ) -> set[str]:
"""
Return all catalogs.
In Snowflake, a catalog is called a "database".
"""
- return sorted(
+ return {
catalog
for (catalog,) in inspector.bind.execute(
"SELECT DATABASE_NAME from information_schema.databases"
)
- )
+ }
@classmethod
def epoch_to_dttm(cls) -> str:
diff --git a/superset/extensions/metadb.py b/superset/extensions/metadb.py
index 2d8444cc9..fd697aea8 100644
--- a/superset/extensions/metadb.py
+++ b/superset/extensions/metadb.py
@@ -270,11 +270,6 @@ class SupersetShillelaghAdapter(Adapter):
self.schema = parts.pop(-1) if parts else None
self.catalog = parts.pop(-1) if parts else None
- if self.catalog:
- # TODO (betodealmeida): when SIP-95 is implemented we should check to see if
- # the database has multi-catalog enabled, and if so, give access.
- raise NotImplementedError("Catalogs are not currently supported")
-
# If the table has a single integer primary key we use that as the row ID in order
# to perform updates and deletes. Otherwise we can only do inserts and selects.
self._rowid: str | None = None
diff --git a/superset/migrations/shared/catalogs.py b/superset/migrations/shared/catalogs.py
new file mode 100644
index 000000000..5d01ecfbf
--- /dev/null
+++ b/superset/migrations/shared/catalogs.py
@@ -0,0 +1,116 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import logging
+
+from alembic import op
+
+from superset import db, security_manager
+from superset.daos.database import DatabaseDAO
+from superset.models.core import Database
+
+logger = logging.getLogger(__name__)
+
+
+def upgrade_schema_perms(engine: str | None = None) -> None:
+ """
+ Update schema permissions to include the catalog part.
+
+ Before SIP-95 schema permissions were stored in the format `[db].[schema]`. With the
+ introduction of catalogs, any existing permissions need to be renamed to include the
+ catalog: `[db].[catalog].[schema]`.
+ """
+ bind = op.get_bind()
+ session = db.Session(bind=bind)
+ for database in session.query(Database).all():
+ db_engine_spec = database.db_engine_spec
+ if (
+ engine and db_engine_spec.engine != engine
+ ) or not db_engine_spec.supports_catalog:
+ continue
+
+ catalog = database.get_default_catalog()
+ ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
+ for schema in database.get_all_schema_names(
+ catalog=catalog,
+ cache=False,
+ ssh_tunnel=ssh_tunnel,
+ ):
+ perm = security_manager.get_schema_perm(
+ database.database_name,
+ None,
+ schema,
+ )
+ existing_pvm = security_manager.find_permission_view_menu(
+ "schema_access",
+ perm,
+ )
+ if existing_pvm:
+ existing_pvm.view_menu.name = security_manager.get_schema_perm(
+ database.database_name,
+ catalog,
+ schema,
+ )
+
+ session.commit()
+
+
+def downgrade_schema_perms(engine: str | None = None) -> None:
+ """
+ Update schema permissions to not have the catalog part.
+
+ Before SIP-95 schema permissions were stored in the format `[db].[schema]`. With the
+ introduction of catalogs, any existing permissions need to be renamed to include the
+ catalog: `[db].[catalog].[schema]`.
+
+    This helper function reverts the process.
+ """
+ bind = op.get_bind()
+ session = db.Session(bind=bind)
+ for database in session.query(Database).all():
+ db_engine_spec = database.db_engine_spec
+ if (
+ engine and db_engine_spec.engine != engine
+ ) or not db_engine_spec.supports_catalog:
+ continue
+
+ catalog = database.get_default_catalog()
+ ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
+ for schema in database.get_all_schema_names(
+ catalog=catalog,
+ cache=False,
+ ssh_tunnel=ssh_tunnel,
+ ):
+ perm = security_manager.get_schema_perm(
+ database.database_name,
+ catalog,
+ schema,
+ )
+ existing_pvm = security_manager.find_permission_view_menu(
+ "schema_access",
+ perm,
+ )
+ if existing_pvm:
+ existing_pvm.view_menu.name = security_manager.get_schema_perm(
+ database.database_name,
+ None,
+ schema,
+ )
+
+ session.commit()
diff --git a/superset/migrations/versions/2024-05-01_10-52_58d051681a3b_add_catalog_perm_to_tables.py b/superset/migrations/versions/2024-05-01_10-52_58d051681a3b_add_catalog_perm_to_tables.py
new file mode 100644
index 000000000..17b33e1d0
--- /dev/null
+++ b/superset/migrations/versions/2024-05-01_10-52_58d051681a3b_add_catalog_perm_to_tables.py
@@ -0,0 +1,53 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Add catalog_perm to tables
+
+Revision ID: 58d051681a3b
+Revises: 4a33124c18ad
+Create Date: 2024-05-01 10:52:31.458433
+
+"""
+
+import sqlalchemy as sa
+from alembic import op
+
+from superset.migrations.shared.catalogs import (
+ downgrade_schema_perms,
+ upgrade_schema_perms,
+)
+
+# revision identifiers, used by Alembic.
+revision = "58d051681a3b"
+down_revision = "4a33124c18ad"
+
+
+def upgrade():
+ op.add_column(
+ "tables",
+ sa.Column("catalog_perm", sa.String(length=1000), nullable=True),
+ )
+ op.add_column(
+ "slices",
+ sa.Column("catalog_perm", sa.String(length=1000), nullable=True),
+ )
+ upgrade_schema_perms(engine="postgresql")
+
+
+def downgrade():
+ op.drop_column("slices", "catalog_perm")
+ op.drop_column("tables", "catalog_perm")
+ downgrade_schema_perms(engine="postgresql")
diff --git a/superset/models/core.py b/superset/models/core.py
index 9a4a1de40..fe486bf2b 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -78,7 +78,7 @@ from superset.sql_parse import Table
from superset.superset_typing import OAuth2ClientConfig, ResultSetColumnType
from superset.utils import cache as cache_util, core as utils
from superset.utils.backports import StrEnum
-from superset.utils.core import get_username
+from superset.utils.core import DatasourceName, get_username
from superset.utils.oauth2 import get_oauth2_access_token
config = app.config
@@ -313,6 +313,14 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
def metadata_cache_timeout(self) -> dict[str, Any]:
return self.get_extra().get("metadata_cache_timeout", {})
+ @property
+ def catalog_cache_enabled(self) -> bool:
+ return "catalog_cache_timeout" in self.metadata_cache_timeout
+
+ @property
+ def catalog_cache_timeout(self) -> int | None:
+ return self.metadata_cache_timeout.get("catalog_cache_timeout")
+
@property
def schema_cache_enabled(self) -> bool:
return "schema_cache_timeout" in self.metadata_cache_timeout
@@ -549,6 +557,18 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
yield conn
+ def get_default_catalog(self) -> str | None:
+ """
+ Return the default configured catalog for the database.
+ """
+ return self.db_engine_spec.get_default_catalog(self)
+
+ def get_default_schema(self, catalog: str | None) -> str | None:
+ """
+ Return the default schema for the database.
+ """
+ return self.db_engine_spec.get_default_schema(self, catalog)
+
def get_default_schema_for_query(self, query: Query) -> str | None:
"""
Return the default schema for a given query.
@@ -706,19 +726,17 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
key="db:{self.id}:schema:{schema}:table_list",
cache=cache_manager.cache,
)
- def get_all_table_names_in_schema( # pylint: disable=unused-argument
+ def get_all_table_names_in_schema(
self,
catalog: str | None,
schema: str,
- cache: bool = False,
- cache_timeout: int | None = None,
- force: bool = False,
- ) -> set[tuple[str, str]]:
+ ) -> set[DatasourceName]:
"""Parameters need to be passed as keyword arguments.
For unused parameters, they are referenced in
cache_util.memoized_func decorator.
+ :param catalog: optional catalog name
:param schema: schema name
:param cache: whether cache is enabled for the function
:param cache_timeout: timeout in seconds for the cache
@@ -728,7 +746,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
try:
with self.get_inspector(catalog=catalog, schema=schema) as inspector:
return {
- (table, schema)
+ DatasourceName(table, schema, catalog)
for table in self.db_engine_spec.get_table_names(
database=self,
inspector=inspector,
@@ -742,19 +760,17 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
key="db:{self.id}:schema:{schema}:view_list",
cache=cache_manager.cache,
)
- def get_all_view_names_in_schema( # pylint: disable=unused-argument
+ def get_all_view_names_in_schema(
self,
catalog: str | None,
schema: str,
- cache: bool = False,
- cache_timeout: int | None = None,
- force: bool = False,
- ) -> set[tuple[str, str]]:
+ ) -> set[DatasourceName]:
"""Parameters need to be passed as keyword arguments.
For unused parameters, they are referenced in
cache_util.memoized_func decorator.
+ :param catalog: optional catalog name
:param schema: schema name
:param cache: whether cache is enabled for the function
:param cache_timeout: timeout in seconds for the cache
@@ -764,7 +780,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
try:
with self.get_inspector(catalog=catalog, schema=schema) as inspector:
return {
- (view, schema)
+ DatasourceName(view, schema, catalog)
for view in self.db_engine_spec.get_view_names(
database=self,
inspector=inspector,
@@ -792,22 +808,17 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
key="db:{self.id}:schema_list",
cache=cache_manager.cache,
)
- def get_all_schema_names( # pylint: disable=unused-argument
+ def get_all_schema_names(
self,
+ *,
catalog: str | None = None,
- cache: bool = False,
- cache_timeout: int | None = None,
- force: bool = False,
ssh_tunnel: SSHTunnel | None = None,
- ) -> list[str]:
- """Parameters need to be passed as keyword arguments.
+ ) -> set[str]:
+ """
+ Return the schemas in a given database
- For unused parameters, they are referenced in
- cache_util.memoized_func decorator.
-
- :param cache: whether cache is enabled for the function
- :param cache_timeout: timeout in seconds for the cache
- :param force: whether to force refresh the cache
+ :param catalog: override default catalog
+ :param ssh_tunnel: SSH tunnel information needed to establish a connection
:return: schema list
"""
try:
@@ -819,6 +830,27 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
except Exception as ex:
raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex
+ @cache_util.memoized_func(
+ key="db:{self.id}:catalog_list",
+ cache=cache_manager.cache,
+ )
+ def get_all_catalog_names(
+ self,
+ *,
+ ssh_tunnel: SSHTunnel | None = None,
+ ) -> set[str]:
+ """
+ Return the catalogs in a given database
+
+ :param ssh_tunnel: SSH tunnel information needed to establish a connection
+ :return: catalog list
+ """
+ try:
+ with self.get_inspector(ssh_tunnel=ssh_tunnel) as inspector:
+ return self.db_engine_spec.get_catalog_names(self, inspector)
+ except Exception as ex:
+ raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex
+
@property
def db_engine_spec(self) -> builtins.type[db_engine_specs.BaseEngineSpec]:
url = make_url_safe(self.sqlalchemy_uri_decrypted)
diff --git a/superset/models/helpers.py b/superset/models/helpers.py
index f77b3ab4d..100391086 100644
--- a/superset/models/helpers.py
+++ b/superset/models/helpers.py
@@ -209,7 +209,10 @@ class ImportExportMixin:
"""Get a mapping of foreign name to the local name of foreign keys"""
parent_rel = cls.__mapper__.relationships.get(cls.export_parent)
if parent_rel:
- return {l.name: r.name for (l, r) in parent_rel.local_remote_pairs} # noqa: E741
+ return {
+ local.name: remote.name
+ for (local, remote) in parent_rel.local_remote_pairs
+ }
return {}
@classmethod
@@ -772,6 +775,10 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
def database(self) -> "Database":
raise NotImplementedError()
+ @property
+ def catalog(self) -> str:
+ raise NotImplementedError()
+
@property
def schema(self) -> str:
raise NotImplementedError()
@@ -1025,7 +1032,12 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
return df
try:
- df = self.database.get_df(sql, self.schema, mutator=assign_column_label)
+ df = self.database.get_df(
+ sql,
+ self.catalog,
+ self.schema,
+ mutator=assign_column_label,
+ )
except Exception as ex: # pylint: disable=broad-except
df = pd.DataFrame()
status = QueryStatus.FAILED
diff --git a/superset/models/slice.py b/superset/models/slice.py
index 197b09a9e..2a0734b10 100644
--- a/superset/models/slice.py
+++ b/superset/models/slice.py
@@ -84,6 +84,7 @@ class Slice( # pylint: disable=too-many-public-methods
cache_timeout = Column(Integer)
perm = Column(String(1000))
schema_perm = Column(String(1000))
+ catalog_perm = Column(String(1000), nullable=True, default=None)
# the last time a user has saved the chart, changed_on is referencing
# when the database row was last written
last_saved_at = Column(DateTime, nullable=True)
diff --git a/superset/security/manager.py b/superset/security/manager.py
index 790301726..d28ed1789 100644
--- a/superset/security/manager.py
+++ b/superset/security/manager.py
@@ -22,7 +22,7 @@ import logging
import re
import time
from collections import defaultdict
-from typing import Any, Callable, cast, NamedTuple, Optional, TYPE_CHECKING, Union
+from typing import Any, Callable, cast, NamedTuple, Optional, TYPE_CHECKING
from flask import current_app, Flask, g, Request
from flask_appbuilder import Model
@@ -67,7 +67,7 @@ from superset.security.guest_token import (
GuestTokenUser,
GuestUser,
)
-from superset.sql_parse import extract_tables_from_jinja_sql
+from superset.sql_parse import extract_tables_from_jinja_sql, Table
from superset.superset_typing import Metric
from superset.utils.core import (
DatasourceName,
@@ -89,7 +89,6 @@ if TYPE_CHECKING:
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.models.sql_lab import Query
- from superset.sql_parse import Table
from superset.viz import BaseViz
logger = logging.getLogger(__name__)
@@ -97,8 +96,9 @@ logger = logging.getLogger(__name__)
DATABASE_PERM_REGEX = re.compile(r"^\[.+\]\.\(id\:(?P<id>\d+)\)$")
-class DatabaseAndSchema(NamedTuple):
+class DatabaseCatalogSchema(NamedTuple):
database: str
+ catalog: Optional[str]
schema: str
@@ -346,17 +346,51 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
return self.get_guest_user_from_request(request)
return None
+ def get_catalog_perm(
+ self,
+ database: str,
+ catalog: Optional[str] = None,
+ ) -> Optional[str]:
+ """
+ Return the database specific catalog permission.
+
+        :param database: The database name
+        :param catalog: The database catalog name
+        :return: The database specific catalog permission
+ """
+ if catalog is None:
+ return None
+
+ return f"[{database}].[{catalog}]"
+
def get_schema_perm(
- self, database: Union["Database", str], schema: Optional[str] = None
+ self,
+ database: str,
+ catalog: Optional[str] = None,
+ schema: Optional[str] = None,
) -> Optional[str]:
"""
Return the database specific schema permission.
- :param database: The Superset database or database name
- :param schema: The Superset schema name
+ Catalogs were added in SIP-95, and not all databases support them. Because of
+ this, the format used for permissions is different depending on whether a
+ catalog is passed or not:
+
+ [database].[schema]
+ [database].[catalog].[schema]
+
+ :param database: The database name
+ :param catalog: The database catalog name
+ :param schema: The database schema name
:return: The database specific schema permission
"""
- return f"[{database}].[{schema}]" if schema else None
+ if schema is None:
+ return None
+
+ if catalog:
+ return f"[{database}].[{catalog}].[{schema}]"
+
+ return f"[{database}].[{schema}]"
@staticmethod
def get_database_perm(database_id: int, database_name: str) -> Optional[str]:
@@ -370,13 +404,6 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
) -> Optional[str]:
return f"[{database_name}].[{dataset_name}](id:{dataset_id})"
- def unpack_database_and_schema(self, schema_permission: str) -> DatabaseAndSchema:
- # [database_name].[schema|table]
-
- schema_name = schema_permission.split(".")[1][1:-1]
- database_name = schema_permission.split(".")[0][1:-1]
- return DatabaseAndSchema(database_name, schema_name)
-
def can_access(self, permission_name: str, view_name: str) -> bool:
"""
Return True if the user can access the FAB permission/view, False otherwise.
@@ -436,6 +463,17 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
or self.can_access("database_access", database.perm) # type: ignore
)
+ def can_access_catalog(self, database: "Database", catalog: str) -> bool:
+ """
+ Return if the user can access the specified catalog.
+ """
+ catalog_perm = self.get_catalog_perm(database.database_name, catalog)
+ return bool(
+ self.can_access_all_datasources()
+ or self.can_access_database(database)
+ or (catalog_perm and self.can_access("catalog_access", catalog_perm))
+ )
+
def can_access_schema(self, datasource: "BaseDatasource") -> bool:
"""
Return True if the user can access the schema associated with specified
@@ -448,6 +486,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
return (
self.can_access_all_datasources()
or self.can_access_database(datasource.database)
+ or self.can_access_catalog(datasource.database, datasource.catalog)
or self.can_access("schema_access", datasource.schema_perm or "")
)
@@ -706,55 +745,150 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
]
def get_schemas_accessible_by_user(
- self, database: "Database", schemas: list[str], hierarchical: bool = True
- ) -> list[str]:
+ self,
+ database: "Database",
+ catalog: Optional[str],
+ schemas: set[str],
+ hierarchical: bool = True,
+ ) -> set[str]:
"""
- Return the list of SQL schemas accessible by the user.
+        Return a filtered list of the schemas accessible by the user.
+
+        If no catalog is specified, the default catalog is used.
:param database: The SQL database
- :param schemas: The list of eligible SQL schemas
+ :param catalog: An optional database catalog
+ :param schemas: A set of candidate schemas
:param hierarchical: Whether to check using the hierarchical permission logic
- :returns: The list of accessible SQL schemas
+ :returns: The set of accessible database schemas
"""
# pylint: disable=import-outside-toplevel
from superset.connectors.sqla.models import SqlaTable
- if hierarchical and self.can_access_database(database):
+ if hierarchical and (
+ self.can_access_database(database)
+ or (catalog and self.can_access_catalog(database, catalog))
+ ):
return schemas
# schema_access
- accessible_schemas = {
- self.unpack_database_and_schema(s).schema
- for s in self.user_view_menu_names("schema_access")
- if s.startswith(f"[{database}].")
- }
+ accessible_schemas: set[str] = set()
+ schema_access = self.user_view_menu_names("schema_access")
+ default_catalog = database.get_default_catalog()
+ default_schema = database.get_default_schema(default_catalog)
+
+ for perm in schema_access:
+ parts = [part[1:-1] for part in perm.split(".")]
+
+ if parts[0] != database.database_name:
+ continue
+
+ # [database].[schema] matches when no catalog is specified, or when the user
+ # specifies the default catalog
+ if len(parts) == 2 and (catalog is None or catalog == default_catalog):
+ accessible_schemas.add(parts[1])
+
+ # [database].[catalog].[schema] matches when the catalog is equal to the
+ # requested catalog or, when no catalog specified, it's equal to the default
+ # catalog.
+ elif len(parts) == 3 and parts[1] == (catalog or default_catalog):
+ accessible_schemas.add(parts[2])
# datasource_access
if perms := self.user_view_menu_names("datasource_access"):
tables = (
self.get_session.query(SqlaTable.schema)
.filter(SqlaTable.database_id == database.id)
- .filter(SqlaTable.schema.isnot(None))
- .filter(SqlaTable.schema != "")
.filter(or_(SqlaTable.perm.in_(perms)))
.distinct()
)
- accessible_schemas.update([table.schema for table in tables])
+ accessible_schemas.update(
+ {
+ table.schema or default_schema # type: ignore
+ for table in tables
+ if (table.schema or default_schema)
+ }
+ )
- return [s for s in schemas if s in accessible_schemas]
+ return schemas & accessible_schemas
+
+ def get_catalogs_accessible_by_user(
+ self,
+ database: "Database",
+ catalogs: set[str],
+ hierarchical: bool = True,
+ ) -> set[str]:
+ """
+        Return a filtered list of the catalogs accessible by the user.
+
+ :param database: The SQL database
+ :param catalogs: A set of candidate catalogs
+ :param hierarchical: Whether to check using the hierarchical permission logic
+ :returns: The set of accessible database catalogs
+ """
+ # pylint: disable=import-outside-toplevel
+ from superset.connectors.sqla.models import SqlaTable
+
+ if hierarchical and self.can_access_database(database):
+ return catalogs
+
+ # catalog access
+ accessible_catalogs: set[str] = set()
+ catalog_access = self.user_view_menu_names("catalog_access")
+ default_catalog = database.get_default_catalog()
+
+ for perm in catalog_access:
+ parts = [part[1:-1] for part in perm.split(".")]
+ if parts[0] == database.database_name:
+ accessible_catalogs.add(parts[1])
+
+ # schema access
+ schema_access = self.user_view_menu_names("schema_access")
+ for perm in schema_access:
+ parts = [part[1:-1] for part in perm.split(".")]
+
+ if parts[0] != database.database_name:
+ continue
+ if len(parts) == 2 and default_catalog:
+ accessible_catalogs.add(default_catalog)
+ elif len(parts) == 3:
+                accessible_catalogs.add(parts[1])
+
+ # datasource_access
+ if perms := self.user_view_menu_names("datasource_access"):
+ tables = (
+                self.get_session.query(SqlaTable.catalog)
+ .filter(SqlaTable.database_id == database.id)
+ .filter(or_(SqlaTable.perm.in_(perms)))
+ .distinct()
+ )
+ accessible_catalogs.update(
+ {
+ table.catalog or default_catalog # type: ignore
+ for table in tables
+ if (table.catalog or default_catalog)
+ }
+ )
+
+ return catalogs & accessible_catalogs
def get_datasources_accessible_by_user( # pylint: disable=invalid-name
self,
database: "Database",
datasource_names: list[DatasourceName],
+ catalog: Optional[str] = None,
schema: Optional[str] = None,
) -> list[DatasourceName]:
"""
- Return the list of SQL tables accessible by the user.
+ Filter list of SQL tables to the ones accessible by the user.
+
+ When catalog and/or schema are specified, it's assumed that all datasources in
+ `datasource_names` are in the given catalog/schema.
:param database: The SQL database
:param datasource_names: The list of eligible SQL tables w/ schema
+ :param catalog: The fallback SQL catalog if not present in the table name
:param schema: The fallback SQL schema if not present in the table name
:returns: The list of accessible SQL tables w/ schema
"""
@@ -764,22 +898,34 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
if self.can_access_database(database):
return datasource_names
+ if catalog:
+ catalog_perm = self.get_catalog_perm(database.database_name, catalog)
+ if catalog_perm and self.can_access("catalog_access", catalog_perm):
+ return datasource_names
+
if schema:
- schema_perm = self.get_schema_perm(database, schema)
+ schema_perm = self.get_schema_perm(database, catalog, schema)
if schema_perm and self.can_access("schema_access", schema_perm):
return datasource_names
user_perms = self.user_view_menu_names("datasource_access")
+ catalog_perms = self.user_view_menu_names("catalog_access")
schema_perms = self.user_view_menu_names("schema_access")
- user_datasources = SqlaTable.query_datasources_by_permissions(
- database, user_perms, schema_perms
- )
- if schema:
- names = {d.table_name for d in user_datasources if d.schema == schema}
- return [d for d in datasource_names if d.table in names]
+ user_datasources = {
+ DatasourceName(table.table_name, table.schema, table.catalog)
+ for table in SqlaTable.query_datasources_by_permissions(
+ database,
+ user_perms,
+ catalog_perms,
+ schema_perms,
+ )
+ }
- full_names = {d.full_name for d in user_datasources}
- return [d for d in datasource_names if f"[{database}].[{d}]" in full_names]
+ return [
+ datasource
+ for datasource in datasource_names
+ if datasource in user_datasources
+ ]
def merge_perm(self, permission_name: str, view_menu_name: str) -> None:
"""
@@ -844,6 +990,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
for datasource in datasources:
merge_pv("datasource_access", datasource.get_perm())
merge_pv("schema_access", datasource.get_schema_perm())
+ merge_pv("catalog_access", datasource.get_catalog_perm())
logger.info("Creating missing database permissions.")
databases = self.get_session.query(models.Database).all()
@@ -1212,7 +1359,12 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
self.get_session.query(self.permissionview_model)
.join(self.permission_model)
.join(self.viewmenu_model)
- .filter(self.permission_model.name == "schema_access")
+ .filter(
+ or_(
+ self.permission_model.name == "schema_access",
+ self.permission_model.name == "catalog_access",
+ )
+ )
.filter(self.viewmenu_model.name.like(f"[{database_name}].[%]"))
.all()
)
@@ -1399,18 +1551,43 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
.values(perm=dataset_perm)
)
+ # update catalog and schema perms
+ values: dict[str, Optional[str]] = {}
+
if target.schema:
dataset_schema_perm = self.get_schema_perm(
- database.database_name, target.schema
+ database.database_name,
+ target.catalog,
+ target.schema,
)
self._insert_pvm_on_sqla_event(
- mapper, connection, "schema_access", dataset_schema_perm
+ mapper,
+ connection,
+ "schema_access",
+ dataset_schema_perm,
)
target.schema_perm = dataset_schema_perm
+ values["schema_perm"] = dataset_schema_perm
+
+ if target.catalog:
+ dataset_catalog_perm = self.get_catalog_perm(
+ database.database_name,
+ target.catalog,
+ )
+ self._insert_pvm_on_sqla_event(
+ mapper,
+ connection,
+ "catalog_access",
+ dataset_catalog_perm,
+ )
+ target.catalog_perm = dataset_catalog_perm
+ values["catalog_perm"] = dataset_catalog_perm
+
+ if values:
connection.execute(
dataset_table.update()
.where(dataset_table.c.id == target.id)
- .values(schema_perm=dataset_schema_perm)
+ .values(**values)
)
def dataset_after_delete(
@@ -1467,6 +1644,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
table.select().where(table.c.id == target.id)
).one()
current_db_id = current_dataset.database_id
+ current_catalog = current_dataset.catalog
current_schema = current_dataset.schema
current_table_name = current_dataset.table_name
@@ -1479,14 +1657,21 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
mapper, connection, target.perm, new_dataset_vm_name, target
)
- # Updates schema permissions
- new_dataset_schema_name = self.get_schema_perm(
- target.database.database_name, target.schema
+ # Updates catalog/schema permissions
+ dataset_catalog_name = self.get_catalog_perm(
+ target.database.database_name,
+ target.catalog,
)
- self._update_dataset_schema_perm(
+ dataset_schema_name = self.get_schema_perm(
+ target.database.database_name,
+ target.catalog,
+ target.schema,
+ )
+ self._update_dataset_catalog_schema_perm(
mapper,
connection,
- new_dataset_schema_name,
+ dataset_catalog_name,
+ dataset_schema_name,
target,
)
@@ -1502,23 +1687,32 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
mapper, connection, old_dataset_vm_name, new_dataset_vm_name, target
)
- # When schema changes
- if current_schema != target.schema:
- new_dataset_schema_name = self.get_schema_perm(
- target.database.database_name, target.schema
+ # When catalog/schema change
+ if current_catalog != target.catalog or current_schema != target.schema:
+ dataset_catalog_name = self.get_catalog_perm(
+ target.database.database_name,
+ target.catalog,
)
- self._update_dataset_schema_perm(
+ dataset_schema_name = self.get_schema_perm(
+ target.database.database_name,
+ target.catalog,
+ target.schema,
+ )
+ self._update_dataset_catalog_schema_perm(
mapper,
connection,
- new_dataset_schema_name,
+ dataset_catalog_name,
+ dataset_schema_name,
target,
)
- def _update_dataset_schema_perm(
+ # pylint: disable=invalid-name, too-many-arguments
+ def _update_dataset_catalog_schema_perm(
self,
mapper: Mapper,
connection: Connection,
- new_schema_permission_name: Optional[str],
+ catalog_permission_name: Optional[str],
+ schema_permission_name: Optional[str],
target: "SqlaTable",
) -> None:
"""
@@ -1530,11 +1724,11 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
:param mapper: The SQLA event mapper
:param connection: The SQLA connection
- :param new_schema_permission_name: The new schema permission name that changed
+ :param catalog_permission_name: The new catalog permission name that changed
+ :param schema_permission_name: The new schema permission name that changed
:param target: Dataset that was updated
:return:
"""
- logger.info("Updating schema perm, new: %s", new_schema_permission_name)
from superset.connectors.sqla.models import ( # pylint: disable=import-outside-toplevel
SqlaTable,
)
@@ -1545,18 +1739,30 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
sqlatable_table = SqlaTable.__table__ # pylint: disable=no-member
chart_table = Slice.__table__ # pylint: disable=no-member
- # insert new schema PVM if it does not exist
+        # insert new PVMs if they don't exist
self._insert_pvm_on_sqla_event(
- mapper, connection, "schema_access", new_schema_permission_name
+ mapper,
+ connection,
+ "catalog_access",
+ catalog_permission_name,
+ )
+ self._insert_pvm_on_sqla_event(
+ mapper,
+ connection,
+ "schema_access",
+ schema_permission_name,
)
- # Update dataset (SqlaTable schema_perm field)
+ # Update dataset
connection.execute(
sqlatable_table.update()
.where(
sqlatable_table.c.id == target.id,
)
- .values(schema_perm=new_schema_permission_name)
+ .values(
+ catalog_perm=catalog_permission_name,
+ schema_perm=schema_permission_name,
+ )
)
# Update charts (Slice schema_perm field)
@@ -1566,7 +1772,10 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
chart_table.c.datasource_id == target.id,
chart_table.c.datasource_type == DatasourceType.TABLE,
)
- .values(schema_perm=new_schema_permission_name)
+ .values(
+ catalog_perm=catalog_permission_name,
+ schema_perm=schema_permission_name,
+ )
)
def _update_dataset_perm( # pylint: disable=too-many-arguments
@@ -1923,7 +2132,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
table: Optional["Table"] = None,
viz: Optional["BaseViz"] = None,
sql: Optional[str] = None,
- catalog: Optional[str] = None, # pylint: disable=unused-argument
+ catalog: Optional[str] = None,
schema: Optional[str] = None,
) -> None:
"""
@@ -1947,7 +2156,6 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.models.sql_lab import Query
- from superset.sql_parse import Table
from superset.utils.core import shortid
if sql and database:
@@ -1955,6 +2163,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
database=database,
sql=sql,
schema=schema,
+ catalog=catalog,
client_id=shortid()[:10],
user_id=get_user_id(),
)
@@ -1970,9 +2179,23 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
return
if query:
+ # Getting the default schema for a query is hard. Users can select the
+ # schema in SQL Lab, but there's no guarantee that the query actually
+ # will run in that schema. Each DB engine spec needs to implement the
+ # necessary logic to enforce that the query runs in the selected schema.
+ # If the DB engine spec doesn't implement the logic the schema is read
+ # from the SQLAlchemy URI if possible; if not, we use the SQLAlchemy
+ # inspector to read it.
default_schema = database.get_default_schema_for_query(query)
+ # Determining the default catalog is much easier, because DB engine
+ # specs need explicit support for catalogs.
+ default_catalog = database.get_default_catalog()
tables = {
- Table(table_.table, table_.schema or default_schema)
+ Table(
+ table_.table,
+ table_.schema or default_schema,
+ table_.catalog or default_catalog,
+ )
for table_ in extract_tables_from_jinja_sql(query.sql, database)
}
elif table:
@@ -1981,21 +2204,36 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
denied = set()
for table_ in tables:
- schema_perm = self.get_schema_perm(database, schema=table_.schema)
+ catalog_perm = self.get_catalog_perm(
+ database.database_name,
+ table_.catalog,
+ )
+ if catalog_perm and self.can_access("catalog_access", catalog_perm):
+ continue
- if not (schema_perm and self.can_access("schema_access", schema_perm)):
- datasources = SqlaTable.query_datasources_by_name(
- database, table_.table, schema=table_.schema
- )
+ schema_perm = self.get_schema_perm(
+ database,
+ table_.catalog,
+ table_.schema,
+ )
+ if schema_perm and self.can_access("schema_access", schema_perm):
+ continue
- # Access to any datasource is suffice.
- for datasource_ in datasources:
- if self.can_access(
- "datasource_access", datasource_.perm
- ) or self.is_owner(datasource_):
- break
- else:
- denied.add(table_)
+ datasources = SqlaTable.query_datasources_by_name(
+ database,
+ table_.table,
+ schema=table_.schema,
+ catalog=table_.catalog,
+ )
+ for datasource_ in datasources:
+ if self.can_access(
+ "datasource_access",
+ datasource_.perm,
+ ) or self.is_owner(datasource_):
+ # access to any datasource is sufficient
+ break
+ else:
+ denied.add(table_)
if denied:
raise SupersetSecurityException(
diff --git a/superset/utils/cache.py b/superset/utils/cache.py
index 48e283e7c..00216fc4b 100644
--- a/superset/utils/cache.py
+++ b/superset/utils/cache.py
@@ -119,7 +119,11 @@ def memoized_func(key: str, cache: Cache = cache_manager.cache) -> Callable[...,
def wrap(f: Callable[..., Any]) -> Callable[..., Any]:
def wrapped_f(*args: Any, **kwargs: Any) -> Any:
- if not kwargs.get("cache", True):
+ should_cache = kwargs.pop("cache", True)
+ force = kwargs.pop("force", False)
+ cache_timeout = kwargs.pop("cache_timeout", 0)
+
+ if not should_cache:
return f(*args, **kwargs)
# format the key using args/kwargs passed to the decorated function
@@ -129,10 +133,10 @@ def memoized_func(key: str, cache: Cache = cache_manager.cache) -> Callable[...,
cache_key = key.format(**bound_args.arguments)
obj = cache.get(cache_key)
- if not kwargs.get("force") and obj is not None:
+ if not force and obj is not None:
return obj
obj = f(*args, **kwargs)
- cache.set(cache_key, obj, timeout=kwargs.get("cache_timeout", 0))
+ cache.set(cache_key, obj, timeout=cache_timeout)
return obj
return wrapped_f
diff --git a/superset/utils/core.py b/superset/utils/core.py
index f02b00443..514cbac75 100644
--- a/superset/utils/core.py
+++ b/superset/utils/core.py
@@ -679,11 +679,13 @@ def generic_find_uq_constraint_name(
def get_datasource_full_name(
- database_name: str, datasource_name: str, schema: str | None = None
+ database_name: str,
+ datasource_name: str,
+ catalog: str | None = None,
+ schema: str | None = None,
) -> str:
- if not schema:
- return f"[{database_name}].[{datasource_name}]"
- return f"[{database_name}].[{schema}].[{datasource_name}]"
+ parts = [database_name, catalog, schema, datasource_name]
+ return ".".join([f"[{part}]" for part in parts if part])
def validate_json(obj: bytes | bytearray | str) -> None:
@@ -1051,8 +1053,8 @@ def merge_extra_form_data(form_data: dict[str, Any]) -> None:
"adhoc_filters", []
)
adhoc_filters.extend(
- {"isExtra": True, **fltr} # type: ignore
- for fltr in append_adhoc_filters
+ {"isExtra": True, **adhoc_filter} # type: ignore
+ for adhoc_filter in append_adhoc_filters
)
if append_filters:
for key, value in form_data.items():
@@ -1502,6 +1504,7 @@ def shortid() -> str:
class DatasourceName(NamedTuple):
table: str
schema: str
+ catalog: str | None = None
def get_stacktrace() -> str | None:
diff --git a/superset/utils/filters.py b/superset/utils/filters.py
index 88154a40b..8c4a07994 100644
--- a/superset/utils/filters.py
+++ b/superset/utils/filters.py
@@ -32,10 +32,12 @@ def get_dataset_access_filters(
database_ids = security_manager.get_accessible_databases()
perms = security_manager.user_view_menu_names("datasource_access")
schema_perms = security_manager.user_view_menu_names("schema_access")
+ catalog_perms = security_manager.user_view_menu_names("catalog_access")
return or_(
Database.id.in_(database_ids),
base_model.perm.in_(perms),
+ base_model.catalog_perm.in_(catalog_perms),
base_model.schema_perm.in_(schema_perms),
*args,
)
diff --git a/superset/views/database/mixins.py b/superset/views/database/mixins.py
index c6e799e6d..0d104aad5 100644
--- a/superset/views/database/mixins.py
+++ b/superset/views/database/mixins.py
@@ -211,11 +211,29 @@ class DatabaseMixin:
utils.parse_ssl_cert(database.server_cert)
database.set_sqlalchemy_uri(database.sqlalchemy_uri)
security_manager.add_permission_view_menu("database_access", database.perm)
- # adding a new database we always want to force refresh schema list
- for schema in database.get_all_schema_names():
- security_manager.add_permission_view_menu(
- "schema_access", security_manager.get_schema_perm(database, schema)
- )
+
+ # add catalog/schema permissions
+ if database.db_engine_spec.supports_catalog:
+ catalogs = database.get_all_catalog_names()
+ for catalog in catalogs:
+ security_manager.add_permission_view_menu(
+ "catalog_access",
+ security_manager.get_catalog_perm(database.database_name, catalog),
+ )
+ else:
+ # add a dummy catalog for DBs that don't support them
+ catalogs = [None]
+
+ for catalog in catalogs:
+ for schema in database.get_all_schema_names(catalog=catalog):
+ security_manager.add_permission_view_menu(
+ "schema_access",
+ security_manager.get_schema_perm(
+ database.database_name,
+ catalog,
+ schema,
+ ),
+ )
def pre_add(self, database: Database) -> None:
self._pre_add_update(database)
diff --git a/tests/integration_tests/databases/api_tests.py b/tests/integration_tests/databases/api_tests.py
index 9e0ee97b1..830625a15 100644
--- a/tests/integration_tests/databases/api_tests.py
+++ b/tests/integration_tests/databases/api_tests.py
@@ -36,7 +36,6 @@ from sqlalchemy.exc import DBAPIError
from sqlalchemy.sql import func
from superset import db, security_manager
-from superset.commands.database.ssh_tunnel.exceptions import SSHTunnelDatabasePortError # noqa: F401
from superset.connectors.sqla.models import SqlaTable
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.databases.utils import make_url_safe # noqa: F401
@@ -296,14 +295,14 @@ class TestDatabaseApi(SupersetTestCase):
"superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
)
@mock.patch("superset.commands.database.create.is_feature_enabled")
- @mock.patch(
- "superset.models.core.Database.get_all_schema_names",
- )
+ @mock.patch("superset.models.core.Database.get_all_catalog_names")
+ @mock.patch("superset.models.core.Database.get_all_schema_names")
def test_create_database_with_ssh_tunnel(
self,
- mock_test_connection_database_command_run,
- mock_create_is_feature_enabled,
mock_get_all_schema_names,
+ mock_get_all_catalog_names,
+ mock_create_is_feature_enabled,
+ mock_test_connection_database_command_run,
):
"""
Database API: Test create with SSH Tunnel
@@ -344,14 +343,14 @@ class TestDatabaseApi(SupersetTestCase):
"superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
)
@mock.patch("superset.commands.database.create.is_feature_enabled")
- @mock.patch(
- "superset.models.core.Database.get_all_schema_names",
- )
+ @mock.patch("superset.models.core.Database.get_all_catalog_names")
+ @mock.patch("superset.models.core.Database.get_all_schema_names")
def test_create_database_with_missing_port_raises_error(
self,
- mock_test_connection_database_command_run,
- mock_create_is_feature_enabled,
mock_get_all_schema_names,
+ mock_get_all_catalog_names,
+ mock_create_is_feature_enabled,
+ mock_test_connection_database_command_run,
):
"""
Database API: Test that missing port raises SSHTunnelDatabaseError
@@ -397,15 +396,15 @@ class TestDatabaseApi(SupersetTestCase):
)
@mock.patch("superset.commands.database.create.is_feature_enabled")
@mock.patch("superset.commands.database.update.is_feature_enabled")
- @mock.patch(
- "superset.models.core.Database.get_all_schema_names",
- )
+ @mock.patch("superset.models.core.Database.get_all_catalog_names")
+ @mock.patch("superset.models.core.Database.get_all_schema_names")
def test_update_database_with_ssh_tunnel(
self,
- mock_test_connection_database_command_run,
- mock_create_is_feature_enabled,
- mock_update_is_feature_enabled,
mock_get_all_schema_names,
+ mock_get_all_catalog_names,
+ mock_update_is_feature_enabled,
+ mock_create_is_feature_enabled,
+ mock_test_connection_database_command_run,
):
"""
Database API: Test update Database with SSH Tunnel
@@ -458,15 +457,15 @@ class TestDatabaseApi(SupersetTestCase):
)
@mock.patch("superset.commands.database.create.is_feature_enabled")
@mock.patch("superset.commands.database.update.is_feature_enabled")
- @mock.patch(
- "superset.models.core.Database.get_all_schema_names",
- )
+ @mock.patch("superset.models.core.Database.get_all_catalog_names")
+ @mock.patch("superset.models.core.Database.get_all_schema_names")
def test_update_database_with_missing_port_raises_error(
self,
- mock_test_connection_database_command_run,
- mock_create_is_feature_enabled,
- mock_update_is_feature_enabled,
mock_get_all_schema_names,
+ mock_get_all_catalog_names,
+ mock_update_is_feature_enabled,
+ mock_create_is_feature_enabled,
+ mock_test_connection_database_command_run,
):
"""
Database API: Test that missing port raises SSHTunnelDatabaseError
@@ -523,16 +522,16 @@ class TestDatabaseApi(SupersetTestCase):
@mock.patch("superset.commands.database.create.is_feature_enabled")
@mock.patch("superset.commands.database.update.is_feature_enabled")
@mock.patch("superset.commands.database.ssh_tunnel.delete.is_feature_enabled")
- @mock.patch(
- "superset.models.core.Database.get_all_schema_names",
- )
+ @mock.patch("superset.models.core.Database.get_all_catalog_names")
+ @mock.patch("superset.models.core.Database.get_all_schema_names")
def test_delete_ssh_tunnel(
self,
- mock_test_connection_database_command_run,
- mock_create_is_feature_enabled,
- mock_update_is_feature_enabled,
- mock_delete_is_feature_enabled,
mock_get_all_schema_names,
+ mock_get_all_catalog_names,
+ mock_delete_is_feature_enabled,
+ mock_update_is_feature_enabled,
+ mock_create_is_feature_enabled,
+ mock_test_connection_database_command_run,
):
"""
Database API: Test deleting a SSH tunnel via Database update
@@ -606,15 +605,15 @@ class TestDatabaseApi(SupersetTestCase):
)
@mock.patch("superset.commands.database.create.is_feature_enabled")
@mock.patch("superset.commands.database.update.is_feature_enabled")
- @mock.patch(
- "superset.models.core.Database.get_all_schema_names",
- )
+ @mock.patch("superset.models.core.Database.get_all_catalog_names")
+ @mock.patch("superset.models.core.Database.get_all_schema_names")
def test_update_ssh_tunnel_via_database_api(
self,
- mock_test_connection_database_command_run,
- mock_create_is_feature_enabled,
- mock_update_is_feature_enabled,
mock_get_all_schema_names,
+ mock_get_all_catalog_names,
+ mock_update_is_feature_enabled,
+ mock_create_is_feature_enabled,
+ mock_test_connection_database_command_run,
):
"""
Database API: Test update SSH Tunnel via Database API
@@ -684,15 +683,15 @@ class TestDatabaseApi(SupersetTestCase):
@mock.patch(
"superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
)
- @mock.patch(
- "superset.models.core.Database.get_all_schema_names",
- )
+ @mock.patch("superset.models.core.Database.get_all_catalog_names")
+ @mock.patch("superset.models.core.Database.get_all_schema_names")
@mock.patch("superset.commands.database.create.is_feature_enabled")
def test_cascade_delete_ssh_tunnel(
self,
- mock_test_connection_database_command_run,
- mock_get_all_schema_names,
mock_create_is_feature_enabled,
+ mock_get_all_schema_names,
+ mock_get_all_catalog_names,
+ mock_test_connection_database_command_run,
):
"""
Database API: SSH Tunnel gets deleted if Database gets deleted
@@ -739,16 +738,16 @@ class TestDatabaseApi(SupersetTestCase):
"superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
)
@mock.patch("superset.commands.database.create.is_feature_enabled")
- @mock.patch(
- "superset.models.core.Database.get_all_schema_names",
- )
+ @mock.patch("superset.models.core.Database.get_all_catalog_names")
+ @mock.patch("superset.models.core.Database.get_all_schema_names")
@mock.patch("superset.extensions.db.session.rollback")
def test_do_not_create_database_if_ssh_tunnel_creation_fails(
self,
- mock_rollback,
- mock_test_connection_database_command_run,
- mock_create_is_feature_enabled,
mock_get_all_schema_names,
+ mock_get_all_catalog_names,
+ mock_create_is_feature_enabled,
+ mock_test_connection_database_command_run,
+ mock_rollback,
):
"""
Database API: Test rollback is called if SSH Tunnel creation fails
@@ -788,14 +787,14 @@ class TestDatabaseApi(SupersetTestCase):
"superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
)
@mock.patch("superset.commands.database.create.is_feature_enabled")
- @mock.patch(
- "superset.models.core.Database.get_all_schema_names",
- )
+ @mock.patch("superset.models.core.Database.get_all_catalog_names")
+ @mock.patch("superset.models.core.Database.get_all_schema_names")
def test_get_database_returns_related_ssh_tunnel(
self,
- mock_test_connection_database_command_run,
- mock_create_is_feature_enabled,
mock_get_all_schema_names,
+ mock_get_all_catalog_names,
+ mock_create_is_feature_enabled,
+ mock_test_connection_database_command_run,
):
"""
Database API: Test GET Database returns its related SSH Tunnel
@@ -842,13 +841,13 @@ class TestDatabaseApi(SupersetTestCase):
@mock.patch(
"superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
)
- @mock.patch(
- "superset.models.core.Database.get_all_schema_names",
- )
+ @mock.patch("superset.models.core.Database.get_all_catalog_names")
+ @mock.patch("superset.models.core.Database.get_all_schema_names")
def test_if_ssh_tunneling_flag_is_not_active_it_raises_new_exception(
self,
- mock_test_connection_database_command_run,
mock_get_all_schema_names,
+ mock_get_all_catalog_names,
+ mock_test_connection_database_command_run,
):
"""
Database API: Test raises SSHTunneling feature flag not enabled
@@ -1987,13 +1986,13 @@ class TestDatabaseApi(SupersetTestCase):
rv = self.client.get(f"api/v1/database/{database.id}/schemas/")
response = json.loads(rv.data.decode("utf-8"))
- self.assertEqual(schemas, response["result"])
+ self.assertEqual(schemas, set(response["result"]))
rv = self.client.get(
f"api/v1/database/{database.id}/schemas/?q={prison.dumps({'force': True})}"
)
response = json.loads(rv.data.decode("utf-8"))
- self.assertEqual(schemas, response["result"])
+ self.assertEqual(schemas, set(response["result"]))
def test_database_schemas_not_found(self):
"""
diff --git a/tests/integration_tests/databases/commands_tests.py b/tests/integration_tests/databases/commands_tests.py
index 2388b38ff..8979b91c4 100644
--- a/tests/integration_tests/databases/commands_tests.py
+++ b/tests/integration_tests/databases/commands_tests.py
@@ -1095,7 +1095,7 @@ class TestTablesDatabaseCommand(SupersetTestCase):
@patch("superset.daos.database.DatabaseDAO.find_by_id")
def test_database_tables_list_with_unknown_database(self, mock_find_by_id):
mock_find_by_id.return_value = None
- command = TablesDatabaseCommand(1, "test", False)
+ command = TablesDatabaseCommand(1, None, "test", False)
with pytest.raises(DatabaseNotFoundError) as excinfo:
command.run()
@@ -1115,7 +1115,7 @@ class TestTablesDatabaseCommand(SupersetTestCase):
mock_can_access_database.side_effect = SupersetException("Test Error")
mock_g.user = security_manager.find_user("admin")
- command = TablesDatabaseCommand(database.id, "main", False)
+ command = TablesDatabaseCommand(database.id, None, "main", False)
with pytest.raises(SupersetException) as excinfo:
command.run()
assert str(excinfo.value) == "Test Error"
@@ -1131,7 +1131,7 @@ class TestTablesDatabaseCommand(SupersetTestCase):
mock_can_access_database.side_effect = Exception("Test Error")
mock_g.user = security_manager.find_user("admin")
- command = TablesDatabaseCommand(database.id, "main", False)
+ command = TablesDatabaseCommand(database.id, None, "main", False)
with pytest.raises(DatabaseTablesUnexpectedError) as excinfo:
command.run()
assert (
@@ -1154,7 +1154,7 @@ class TestTablesDatabaseCommand(SupersetTestCase):
if database.backend == "postgresql" or database.backend == "mysql":
return
- command = TablesDatabaseCommand(database.id, schema_name, False)
+ command = TablesDatabaseCommand(database.id, None, schema_name, False)
result = command.run()
assert result["count"] > 0
diff --git a/tests/integration_tests/db_engine_specs/postgres_tests.py b/tests/integration_tests/db_engine_specs/postgres_tests.py
index 708b94987..f21dbf54a 100644
--- a/tests/integration_tests/db_engine_specs/postgres_tests.py
+++ b/tests/integration_tests/db_engine_specs/postgres_tests.py
@@ -531,7 +531,7 @@ def test_get_catalog_names(app_context: AppContext) -> None:
return
with database.get_inspector() as inspector:
- assert PostgresEngineSpec.get_catalog_names(database, inspector) == [
+ assert PostgresEngineSpec.get_catalog_names(database, inspector) == {
"postgres",
"superset",
- ]
+ }
diff --git a/tests/integration_tests/model_tests.py b/tests/integration_tests/model_tests.py
index b68cb7c05..4dc15a2af 100644
--- a/tests/integration_tests/model_tests.py
+++ b/tests/integration_tests/model_tests.py
@@ -343,20 +343,20 @@ class TestDatabaseModel(SupersetTestCase):
main_db = get_example_database()
if main_db.backend == "mysql":
- df = main_db.get_df("SELECT 1", None)
+ df = main_db.get_df("SELECT 1", None, None)
self.assertEqual(df.iat[0, 0], 1)
- df = main_db.get_df("SELECT 1;", None)
+ df = main_db.get_df("SELECT 1;", None, None)
self.assertEqual(df.iat[0, 0], 1)
def test_multi_statement(self):
main_db = get_example_database()
if main_db.backend == "mysql":
- df = main_db.get_df("USE superset; SELECT 1", None)
+ df = main_db.get_df("USE superset; SELECT 1", None, None)
self.assertEqual(df.iat[0, 0], 1)
- df = main_db.get_df("USE superset; SELECT ';';", None)
+ df = main_db.get_df("USE superset; SELECT ';';", None, None)
self.assertEqual(df.iat[0, 0], ";")
@mock.patch("superset.models.core.create_engine")
diff --git a/tests/integration_tests/security_tests.py b/tests/integration_tests/security_tests.py
index f22847ca5..5eca7b4a6 100644
--- a/tests/integration_tests/security_tests.py
+++ b/tests/integration_tests/security_tests.py
@@ -898,7 +898,7 @@ class TestRolePermission(SupersetTestCase):
db.session.query(SqlaTable).filter_by(table_name="tmp_table1").one()
)
self.assertEqual(changed_table1.perm, f"[tmp_db2].[tmp_table1](id:{table1.id})")
- self.assertEqual(changed_table1.schema_perm, f"[tmp_db2].[tmp_schema]") # noqa: F541
+ self.assertEqual(changed_table1.schema_perm, "[tmp_db2].[tmp_schema]")
# Test Chart permission changed
slice1 = db.session.query(Slice).filter_by(slice_name="tmp_slice1").one()
@@ -956,12 +956,12 @@ class TestRolePermission(SupersetTestCase):
db.session.query(SqlaTable).filter_by(table_name="tmp_table1").one()
)
self.assertEqual(changed_table1.perm, f"[tmp_db1].[tmp_table1](id:{table1.id})")
- self.assertEqual(changed_table1.schema_perm, f"[tmp_db1].[tmp_schema_changed]") # noqa: F541
+ self.assertEqual(changed_table1.schema_perm, "[tmp_db1].[tmp_schema_changed]")
# Test Chart schema permission changed
slice1 = db.session.query(Slice).filter_by(slice_name="tmp_slice1").one()
self.assertEqual(slice1.perm, f"[tmp_db1].[tmp_table1](id:{table1.id})")
- self.assertEqual(slice1.schema_perm, f"[tmp_db1].[tmp_schema_changed]") # noqa: F541
+ self.assertEqual(slice1.schema_perm, "[tmp_db1].[tmp_schema_changed]")
# cleanup
db.session.delete(slice1)
@@ -1069,7 +1069,7 @@ class TestRolePermission(SupersetTestCase):
self.assertEqual(
changed_table1.perm, f"[tmp_db2].[tmp_table1_changed](id:{table1.id})"
)
- self.assertEqual(changed_table1.schema_perm, f"[tmp_db2].[tmp_schema]") # noqa: F541
+ self.assertEqual(changed_table1.schema_perm, "[tmp_db2].[tmp_schema]")
# Test Chart permission changed
slice1 = db.session.query(Slice).filter_by(slice_name="tmp_slice1").one()
@@ -1158,9 +1158,9 @@ class TestRolePermission(SupersetTestCase):
with self.client.application.test_request_context():
database = get_example_database()
schemas = security_manager.get_schemas_accessible_by_user(
- database, ["1", "2", "3"]
+ database, None, {"1", "2", "3"}
)
- self.assertEqual(schemas, ["1", "2", "3"]) # no changes
+ self.assertEqual(schemas, {"1", "2", "3"}) # no changes
@patch("superset.utils.core.g")
@patch("superset.security.manager.g")
@@ -1171,10 +1171,10 @@ class TestRolePermission(SupersetTestCase):
with self.client.application.test_request_context():
database = get_example_database()
schemas = security_manager.get_schemas_accessible_by_user(
- database, ["1", "2", "3"]
+ database, None, {"1", "2", "3"}
)
# temp_schema is not passed in the params
- self.assertEqual(schemas, ["1"])
+ self.assertEqual(schemas, {"1"})
delete_schema_perm("[examples].[1]")
def test_schemas_accessible_by_user_datasource_access(self):
@@ -1183,9 +1183,9 @@ class TestRolePermission(SupersetTestCase):
with self.client.application.test_request_context():
with override_user(security_manager.find_user("gamma")):
schemas = security_manager.get_schemas_accessible_by_user(
- database, ["temp_schema", "2", "3"]
+ database, None, {"temp_schema", "2", "3"}
)
- self.assertEqual(schemas, ["temp_schema"])
+ self.assertEqual(schemas, {"temp_schema"})
def test_schemas_accessible_by_user_datasource_and_schema_access(self):
# User has schema access to the datasource temp_schema.wb_health_population in examples DB.
@@ -1194,9 +1194,9 @@ class TestRolePermission(SupersetTestCase):
database = get_example_database()
with override_user(security_manager.find_user("gamma")):
schemas = security_manager.get_schemas_accessible_by_user(
- database, ["temp_schema", "2", "3"]
+ database, None, {"temp_schema", "2", "3"}
)
- self.assertEqual(schemas, ["temp_schema", "2"])
+ self.assertEqual(schemas, {"temp_schema", "2"})
vm = security_manager.find_permission_view_menu(
"schema_access", "[examples].[2]"
)
diff --git a/tests/integration_tests/sqllab_tests.py b/tests/integration_tests/sqllab_tests.py
index 942653de1..96019c16c 100644
--- a/tests/integration_tests/sqllab_tests.py
+++ b/tests/integration_tests/sqllab_tests.py
@@ -284,9 +284,15 @@ class TestSqlLab(SupersetTestCase):
# sqlite doesn't support database creation
return
+ catalog = examples_db.get_default_catalog()
sqllab_test_db_schema_permission_view = (
security_manager.add_permission_view_menu(
- "schema_access", f"[{examples_db.name}].[{CTAS_SCHEMA_NAME}]"
+ "schema_access",
+ security_manager.get_schema_perm(
+ examples_db.name,
+ catalog,
+ CTAS_SCHEMA_NAME,
+ ),
)
)
schema_perm_role = security_manager.add_role("SchemaPermission")
diff --git a/tests/unit_tests/commands/databases/create_test.py b/tests/unit_tests/commands/databases/create_test.py
new file mode 100644
index 000000000..405238827
--- /dev/null
+++ b/tests/unit_tests/commands/databases/create_test.py
@@ -0,0 +1,128 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from unittest.mock import MagicMock
+
+import pytest
+from pytest_mock import MockerFixture
+
+from superset.commands.database.create import CreateDatabaseCommand
+from superset.extensions import security_manager
+
+
+@pytest.fixture()
+def database_with_catalog(mocker: MockerFixture) -> MagicMock:
+ """
+ Mock a database with catalogs and schemas.
+ """
+ mocker.patch("superset.commands.database.create.db")
+ mocker.patch("superset.commands.database.create.TestConnectionDatabaseCommand")
+
+ database = mocker.MagicMock()
+ database.database_name = "test_database"
+ database.db_engine_spec.__name__ = "test_engine"
+ database.db_engine_spec.supports_catalog = True
+ database.get_all_catalog_names.return_value = ["catalog1", "catalog2"]
+ database.get_all_schema_names.side_effect = [
+ {"schema1", "schema2"},
+ {"schema3", "schema4"},
+ ]
+
+ DatabaseDAO = mocker.patch("superset.commands.database.create.DatabaseDAO")
+ DatabaseDAO.create.return_value = database
+
+ return database
+
+
+@pytest.fixture()
+def database_without_catalog(mocker: MockerFixture) -> MagicMock:
+ """
+ Mock a database without catalogs.
+ """
+ mocker.patch("superset.commands.database.create.db")
+ mocker.patch("superset.commands.database.create.TestConnectionDatabaseCommand")
+
+ database = mocker.MagicMock()
+ database.database_name = "test_database"
+ database.db_engine_spec.__name__ = "test_engine"
+ database.db_engine_spec.supports_catalog = False
+ database.get_all_schema_names.return_value = ["schema1", "schema2"]
+
+ DatabaseDAO = mocker.patch("superset.commands.database.create.DatabaseDAO")
+ DatabaseDAO.create.return_value = database
+
+ return database
+
+
+def test_create_permissions_with_catalog(
+ mocker: MockerFixture,
+ database_with_catalog: MockerFixture,
+) -> None:
+ """
+ Test that permissions are created when a database with a catalog is created.
+ """
+ add_permission_view_menu = mocker.patch.object(
+ security_manager,
+ "add_permission_view_menu",
+ )
+
+ CreateDatabaseCommand(
+ {
+ "database_name": "test_database",
+ "sqlalchemy_uri": "sqlite://",
+ }
+ ).run()
+
+ add_permission_view_menu.assert_has_calls(
+ [
+ mocker.call("catalog_access", "[test_database].[catalog1]"),
+ mocker.call("catalog_access", "[test_database].[catalog2]"),
+ mocker.call("schema_access", "[test_database].[catalog1].[schema1]"),
+ mocker.call("schema_access", "[test_database].[catalog1].[schema2]"),
+ mocker.call("schema_access", "[test_database].[catalog2].[schema3]"),
+ mocker.call("schema_access", "[test_database].[catalog2].[schema4]"),
+ ],
+ any_order=True,
+ )
+
+
+def test_create_permissions_without_catalog(
+ mocker: MockerFixture,
+ database_without_catalog: MockerFixture,
+) -> None:
+ """
+ Test that permissions are created when a database without a catalog is created.
+ """
+ add_permission_view_menu = mocker.patch.object(
+ security_manager,
+ "add_permission_view_menu",
+ )
+
+ CreateDatabaseCommand(
+ {
+ "database_name": "test_database",
+ "sqlalchemy_uri": "sqlite://",
+ }
+ ).run()
+
+ add_permission_view_menu.assert_has_calls(
+ [
+ mocker.call("schema_access", "[test_database].[schema1]"),
+ mocker.call("schema_access", "[test_database].[schema2]"),
+ ],
+ any_order=True,
+ )
diff --git a/tests/unit_tests/commands/databases/tables_test.py b/tests/unit_tests/commands/databases/tables_test.py
new file mode 100644
index 000000000..d9a8583f9
--- /dev/null
+++ b/tests/unit_tests/commands/databases/tables_test.py
@@ -0,0 +1,203 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from unittest.mock import MagicMock
+
+import pytest
+from pytest_mock import MockerFixture
+
+from superset.commands.database.tables import TablesDatabaseCommand
+from superset.extensions import security_manager
+from superset.utils.core import DatasourceName
+
+
+@pytest.fixture()
+def database_with_catalog(mocker: MockerFixture) -> MagicMock:
+ """
+ Mock a database with catalogs and schemas.
+ """
+ mocker.patch("superset.commands.database.tables.db")
+
+ database = mocker.MagicMock()
+ database.database_name = "test_database"
+ database.get_all_table_names_in_schema.return_value = [
+ DatasourceName("table1", "schema1", "catalog1"),
+ DatasourceName("table2", "schema1", "catalog1"),
+ ]
+ database.get_all_view_names_in_schema.return_value = [
+ DatasourceName("view1", "schema1", "catalog1"),
+ ]
+
+ DatabaseDAO = mocker.patch("superset.commands.database.tables.DatabaseDAO")
+ DatabaseDAO.find_by_id.return_value = database
+
+ return database
+
+
+@pytest.fixture()
+def database_without_catalog(mocker: MockerFixture) -> MagicMock:
+ """
+ Mock a database without catalogs but with schemas.
+ """
+ mocker.patch("superset.commands.database.tables.db")
+
+ database = mocker.MagicMock()
+ database.database_name = "test_database"
+ database.get_all_table_names_in_schema.return_value = [
+ DatasourceName("table1", "schema1"),
+ DatasourceName("table2", "schema1"),
+ ]
+ database.get_all_view_names_in_schema.return_value = [
+ DatasourceName("view1", "schema1"),
+ ]
+
+ DatabaseDAO = mocker.patch("superset.commands.database.tables.DatabaseDAO")
+ DatabaseDAO.find_by_id.return_value = database
+
+ return database
+
+
+def test_tables_with_catalog(
+ mocker: MockerFixture,
+ database_with_catalog: MockerFixture,
+) -> None:
+ """
+ Test listing tables and views for a database with a catalog.
+ """
+ get_datasources_accessible_by_user = mocker.patch.object(
+ security_manager,
+ "get_datasources_accessible_by_user",
+ side_effect=[
+ {
+ DatasourceName("table1", "schema1", "catalog1"),
+ DatasourceName("table2", "schema1", "catalog1"),
+ },
+ {DatasourceName("view1", "schema1", "catalog1")},
+ ],
+ )
+
+ db = mocker.patch("superset.commands.database.tables.db")
+ table = mocker.MagicMock()
+ table.name = "table1"
+ table.extra_dict = {"foo": "bar"}
+ db.session.query().filter().options().all.return_value = [table]
+
+ payload = TablesDatabaseCommand(1, "catalog1", "schema1", False).run()
+ assert payload == {
+ "count": 3,
+ "result": [
+ {"value": "table1", "type": "table", "extra": {"foo": "bar"}},
+ {"value": "table2", "type": "table", "extra": None},
+ {"value": "view1", "type": "view"},
+ ],
+ }
+
+ get_datasources_accessible_by_user.assert_has_calls(
+ [
+ mocker.call(
+ database=database_with_catalog,
+ catalog="catalog1",
+ schema="schema1",
+ datasource_names=[
+ DatasourceName("table1", "schema1", "catalog1"),
+ DatasourceName("table2", "schema1", "catalog1"),
+ ],
+ ),
+ mocker.call(
+ database=database_with_catalog,
+ catalog="catalog1",
+ schema="schema1",
+ datasource_names=[
+ DatasourceName("view1", "schema1", "catalog1"),
+ ],
+ ),
+ ],
+ )
+
+ database_with_catalog.get_all_table_names_in_schema.assert_called_with(
+ catalog="catalog1",
+ schema="schema1",
+ force=False,
+ cache=database_with_catalog.table_cache_enabled,
+ cache_timeout=database_with_catalog.table_cache_timeout,
+ )
+
+
+def test_tables_without_catalog(
+ mocker: MockerFixture,
+ database_without_catalog: MockerFixture,
+) -> None:
+ """
+ Test listing tables and views for a database without a catalog.
+ """
+ get_datasources_accessible_by_user = mocker.patch.object(
+ security_manager,
+ "get_datasources_accessible_by_user",
+ side_effect=[
+ {
+ DatasourceName("table1", "schema1"),
+ DatasourceName("table2", "schema1"),
+ },
+ {DatasourceName("view1", "schema1")},
+ ],
+ )
+
+ db = mocker.patch("superset.commands.database.tables.db")
+ table = mocker.MagicMock()
+ table.name = "table1"
+ table.extra_dict = {"foo": "bar"}
+ db.session.query().filter().options().all.return_value = [table]
+
+ payload = TablesDatabaseCommand(1, None, "schema1", False).run()
+ assert payload == {
+ "count": 3,
+ "result": [
+ {"value": "table1", "type": "table", "extra": {"foo": "bar"}},
+ {"value": "table2", "type": "table", "extra": None},
+ {"value": "view1", "type": "view"},
+ ],
+ }
+
+ get_datasources_accessible_by_user.assert_has_calls(
+ [
+ mocker.call(
+ database=database_without_catalog,
+ catalog=None,
+ schema="schema1",
+ datasource_names=[
+ DatasourceName("table1", "schema1"),
+ DatasourceName("table2", "schema1"),
+ ],
+ ),
+ mocker.call(
+ database=database_without_catalog,
+ catalog=None,
+ schema="schema1",
+ datasource_names=[
+ DatasourceName("view1", "schema1"),
+ ],
+ ),
+ ],
+ )
+
+ database_without_catalog.get_all_table_names_in_schema.assert_called_with(
+ catalog=None,
+ schema="schema1",
+ force=False,
+ cache=database_without_catalog.table_cache_enabled,
+ cache_timeout=database_without_catalog.table_cache_timeout,
+ )
diff --git a/tests/unit_tests/commands/databases/update_test.py b/tests/unit_tests/commands/databases/update_test.py
new file mode 100644
index 000000000..300efb62e
--- /dev/null
+++ b/tests/unit_tests/commands/databases/update_test.py
@@ -0,0 +1,272 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from unittest.mock import MagicMock
+
+import pytest
+from pytest_mock import MockerFixture
+
+from superset.commands.database.update import UpdateDatabaseCommand
+from superset.extensions import security_manager
+
+
+@pytest.fixture()
+def database_with_catalog(mocker: MockerFixture) -> MagicMock:
+ """
+ Mock a database with catalogs and schemas.
+ """
+ mocker.patch("superset.commands.database.update.db")
+
+ database = mocker.MagicMock()
+ database.database_name = "my_db"
+ database.db_engine_spec.__name__ = "test_engine"
+ database.db_engine_spec.supports_catalog = True
+ database.get_all_catalog_names.return_value = ["catalog1", "catalog2"]
+ database.get_all_schema_names.side_effect = [
+ ["schema1", "schema2"],
+ ["schema3", "schema4"],
+ ]
+ database.get_default_catalog.return_value = "catalog2"
+
+ return database
+
+
+@pytest.fixture()
+def database_without_catalog(mocker: MockerFixture) -> MagicMock:
+ """
+ Mock a database without catalogs.
+ """
+ mocker.patch("superset.commands.database.update.db")
+
+ database = mocker.MagicMock()
+ database.database_name = "my_db"
+ database.db_engine_spec.__name__ = "test_engine"
+ database.db_engine_spec.supports_catalog = False
+ database.get_all_schema_names.return_value = ["schema1", "schema2"]
+
+ return database
+
+
+def test_update_with_catalog(
+ mocker: MockerFixture,
+ database_with_catalog: MagicMock,
+) -> None:
+ """
+ Test that permissions are updated correctly.
+
+ In this test, the database has two catalogs with two schemas each:
+
+ - catalog1
+ - schema1
+ - schema2
+ - catalog2
+ - schema3
+ - schema4
+
+ When update is called, only `catalog2.schema3` has permissions associated with it,
+ so `catalog1.*` and `catalog2.schema4` are added.
+ """
+ DatabaseDAO = mocker.patch("superset.commands.database.update.DatabaseDAO")
+ DatabaseDAO.find_by_id.return_value = database_with_catalog
+ DatabaseDAO.update.return_value = database_with_catalog
+
+ find_permission_view_menu = mocker.patch.object(
+ security_manager,
+ "find_permission_view_menu",
+ )
+ find_permission_view_menu.side_effect = [
+ None, # first catalog is new
+ "[my_db].[catalog2]", # second catalog already exists
+ "[my_db].[catalog2].[schema3]", # first schema already exists
+ None, # second schema is new
+ # these are called when checking for existing perms in [db].[schema] format
+ None,
+ None,
+ ]
+ add_permission_view_menu = mocker.patch.object(
+ security_manager,
+ "add_permission_view_menu",
+ )
+
+ UpdateDatabaseCommand(1, {}).run()
+
+ add_permission_view_menu.assert_has_calls(
+ [
+ # first catalog is added with all schemas
+ mocker.call("catalog_access", "[my_db].[catalog1]"),
+ mocker.call("schema_access", "[my_db].[catalog1].[schema1]"),
+ mocker.call("schema_access", "[my_db].[catalog1].[schema2]"),
+ # second catalog already exists, only `schema4` is added
+ mocker.call("schema_access", "[my_db].[catalog2].[schema4]"),
+ ],
+ )
+
+
+def test_update_without_catalog(
+ mocker: MockerFixture,
+ database_without_catalog: MagicMock,
+) -> None:
+ """
+ Test that permissions are updated correctly.
+
+ In this test, the database has no catalogs and two schemas:
+
+ - schema1
+ - schema2
+
+ When update is called, only `schema2` has permissions associated with it, so `schema1`
+ is added.
+ """
+ DatabaseDAO = mocker.patch("superset.commands.database.update.DatabaseDAO")
+ DatabaseDAO.find_by_id.return_value = database_without_catalog
+ DatabaseDAO.update.return_value = database_without_catalog
+
+ find_permission_view_menu = mocker.patch.object(
+ security_manager,
+ "find_permission_view_menu",
+ )
+ find_permission_view_menu.side_effect = [
+ None, # schema1 has no permissions
+ "[my_db].[schema2]", # second schema already exists
+ ]
+ add_permission_view_menu = mocker.patch.object(
+ security_manager,
+ "add_permission_view_menu",
+ )
+
+ UpdateDatabaseCommand(1, {}).run()
+
+ add_permission_view_menu.assert_called_with(
+ "schema_access",
+ "[my_db].[schema1]",
+ )
+
+
+def test_rename_with_catalog(
+ mocker: MockerFixture,
+ database_with_catalog: MagicMock,
+) -> None:
+ """
+ Test that permissions are renamed correctly.
+
+ In this test, the database has two catalogs with two schemas each:
+
+ - catalog1
+ - schema1
+ - schema2
+ - catalog2
+ - schema3
+ - schema4
+
+ When update is called, only `catalog2.schema3` has permissions associated with it,
+ so `catalog1.*` and `catalog2.schema4` are added. Additionally, the database has
+ been renamed from `my_db` to `my_other_db`.
+ """
+ DatabaseDAO = mocker.patch("superset.commands.database.update.DatabaseDAO")
+ original_database = mocker.MagicMock()
+ original_database.database_name = "my_db"
+ DatabaseDAO.find_by_id.return_value = original_database
+ database_with_catalog.database_name = "my_other_db"
+ DatabaseDAO.update.return_value = database_with_catalog
+ DatabaseDAO.get_datasets.return_value = []
+
+ find_permission_view_menu = mocker.patch.object(
+ security_manager,
+ "find_permission_view_menu",
+ )
+ catalog2_pvm = mocker.MagicMock()
+ catalog2_schema3_pvm = mocker.MagicMock()
+ find_permission_view_menu.side_effect = [
+ # these are called when adding the permissions:
+ None, # first catalog is new
+ "[my_db].[catalog2]", # second catalog already exists
+ "[my_db].[catalog2].[schema3]", # first schema already exists
+ None, # second schema is new
+ # these are called when renaming the permissions:
+ catalog2_pvm, # old [my_db].[catalog2]
+ catalog2_schema3_pvm, # old [my_db].[catalog2].[schema3]
+ None, # [my_db].[catalog2].[schema4] doesn't exist
+ ]
+ add_permission_view_menu = mocker.patch.object(
+ security_manager,
+ "add_permission_view_menu",
+ )
+
+ UpdateDatabaseCommand(1, {}).run()
+
+ add_permission_view_menu.assert_has_calls(
+ [
+ # first catalog is added with all schemas with the new DB name
+ mocker.call("catalog_access", "[my_other_db].[catalog1]"),
+ mocker.call("schema_access", "[my_other_db].[catalog1].[schema1]"),
+ mocker.call("schema_access", "[my_other_db].[catalog1].[schema2]"),
+ # second catalog already exists, only `schema4` is added
+ mocker.call("schema_access", "[my_other_db].[catalog2].[schema4]"),
+ ],
+ )
+
+ assert catalog2_pvm.view_menu.name == "[my_other_db].[catalog2]"
+ assert catalog2_schema3_pvm.view_menu.name == "[my_other_db].[catalog2].[schema3]"
+
+
+def test_rename_without_catalog(
+ mocker: MockerFixture,
+ database_without_catalog: MagicMock,
+) -> None:
+ """
+ Test that permissions are renamed correctly.
+
+ In this test, the database has no catalogs and two schemas:
+
+ - schema1
+ - schema2
+
+ When update is called, only `schema2` has permissions associated with it, so `schema1`
+ is added. Additionally, the database has been renamed from `my_db` to `my_other_db`.
+ """
+ DatabaseDAO = mocker.patch("superset.commands.database.update.DatabaseDAO")
+ original_database = mocker.MagicMock()
+ original_database.database_name = "my_db"
+ DatabaseDAO.find_by_id.return_value = original_database
+ database_without_catalog.database_name = "my_other_db"
+ DatabaseDAO.update.return_value = database_without_catalog
+ DatabaseDAO.get_datasets.return_value = []
+
+ find_permission_view_menu = mocker.patch.object(
+ security_manager,
+ "find_permission_view_menu",
+ )
+ schema2_pvm = mocker.MagicMock()
+ find_permission_view_menu.side_effect = [
+ None, # schema1 has no permissions
+ "[my_db].[schema2]", # second schema already exists
+ None, # [my_db].[schema1] doesn't exist
+ schema2_pvm, # old [my_db].[schema2]
+ ]
+ add_permission_view_menu = mocker.patch.object(
+ security_manager,
+ "add_permission_view_menu",
+ )
+
+ UpdateDatabaseCommand(1, {}).run()
+
+ add_permission_view_menu.assert_called_with(
+ "schema_access",
+ "[my_other_db].[schema1]",
+ )
+
+ assert schema2_pvm.view_menu.name == "[my_other_db].[schema2]"
diff --git a/tests/unit_tests/conftest.py b/tests/unit_tests/conftest.py
index 71b3bff33..3905a15b3 100644
--- a/tests/unit_tests/conftest.py
+++ b/tests/unit_tests/conftest.py
@@ -88,6 +88,8 @@ def app(request: SubRequest) -> Iterator[SupersetApp]:
app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] = False
app.config["TESTING"] = True
app.config["RATELIMIT_ENABLED"] = False
+ app.config["CACHE_CONFIG"] = {}
+ app.config["DATA_CACHE_CONFIG"] = {}
# loop over extra configs passed in by tests
# and update the app config
diff --git a/tests/unit_tests/connectors/sqla/models_test.py b/tests/unit_tests/connectors/sqla/models_test.py
index 00b4b0a31..687295baf 100644
--- a/tests/unit_tests/connectors/sqla/models_test.py
+++ b/tests/unit_tests/connectors/sqla/models_test.py
@@ -17,9 +17,11 @@
import pytest
from pytest_mock import MockerFixture
+from sqlalchemy import create_engine
from superset.connectors.sqla.models import SqlaTable
from superset.exceptions import OAuth2RedirectError
+from superset.models.core import Database
from superset.superset_typing import QueryObjectDict
@@ -64,3 +66,124 @@ def test_query_bubbles_errors(mocker: MockerFixture) -> None:
}
with pytest.raises(OAuth2RedirectError):
sqla_table.query(query_obj)
+
+
+def test_permissions_without_catalog() -> None:
+ """
+ Test permissions when the table has no catalog.
+ """
+ database = Database(database_name="my_db")
+ sqla_table = SqlaTable(
+ table_name="my_sqla_table",
+ columns=[],
+ metrics=[],
+ database=database,
+ schema="schema1",
+ catalog=None,
+ id=1,
+ )
+
+ assert sqla_table.get_perm() == "[my_db].[my_sqla_table](id:1)"
+ assert sqla_table.get_catalog_perm() is None
+ assert sqla_table.get_schema_perm() == "[my_db].[schema1]"
+
+
+def test_permissions_with_catalog() -> None:
+ """
+ Test permissions when the table has a catalog set.
+ """
+ database = Database(database_name="my_db")
+ sqla_table = SqlaTable(
+ table_name="my_sqla_table",
+ columns=[],
+ metrics=[],
+ database=database,
+ schema="schema1",
+ catalog="db1",
+ id=1,
+ )
+
+ assert sqla_table.get_perm() == "[my_db].[my_sqla_table](id:1)"
+ assert sqla_table.get_catalog_perm() == "[my_db].[db1]"
+ assert sqla_table.get_schema_perm() == "[my_db].[db1].[schema1]"
+
+
+def test_query_datasources_by_name(mocker: MockerFixture) -> None:
+ """
+ Test the `query_datasources_by_name` method.
+ """
+ db = mocker.patch("superset.connectors.sqla.models.db")
+
+ database = Database(database_name="my_db", id=1)
+ sqla_table = SqlaTable(
+ table_name="my_sqla_table",
+ columns=[],
+ metrics=[],
+ database=database,
+ )
+
+ sqla_table.query_datasources_by_name(database, "my_table")
+ db.session.query().filter_by.assert_called_with(
+ database_id=1,
+ table_name="my_table",
+ )
+
+ sqla_table.query_datasources_by_name(database, "my_table", "db1", "schema1")
+ db.session.query().filter_by.assert_called_with(
+ database_id=1,
+ table_name="my_table",
+ catalog="db1",
+ schema="schema1",
+ )
+
+
+def test_query_datasources_by_permissions(mocker: MockerFixture) -> None:
+ """
+ Test the `query_datasources_by_permissions` method.
+ """
+ db = mocker.patch("superset.connectors.sqla.models.db")
+
+ engine = create_engine("sqlite://")
+ database = Database(database_name="my_db", id=1)
+ sqla_table = SqlaTable(
+ table_name="my_sqla_table",
+ columns=[],
+ metrics=[],
+ database=database,
+ )
+
+ sqla_table.query_datasources_by_permissions(database, set(), set(), set())
+ db.session.query().filter_by.assert_called_with(database_id=1)
+ clause = db.session.query().filter_by().filter.mock_calls[0].args[0]
+ assert str(clause.compile(engine, compile_kwargs={"literal_binds": True})) == ""
+
+
+def test_query_datasources_by_permissions_with_catalog_schema(
+ mocker: MockerFixture,
+) -> None:
+ """
+ Test the `query_datasources_by_permissions` method passing a catalog and schema.
+ """
+ db = mocker.patch("superset.connectors.sqla.models.db")
+
+ engine = create_engine("sqlite://")
+ database = Database(database_name="my_db", id=1)
+ sqla_table = SqlaTable(
+ table_name="my_sqla_table",
+ columns=[],
+ metrics=[],
+ database=database,
+ )
+ sqla_table.query_datasources_by_permissions(
+ database,
+ {"[my_db].[table1](id:1)"},
+ {"[my_db].[db1]"},
+ # pass as list to have deterministic order for test
+ ["[my_db].[db1].[schema1]", "[my_other_db].[schema]"], # type: ignore
+ )
+ clause = db.session.query().filter_by().filter.mock_calls[0].args[0]
+ assert str(clause.compile(engine, compile_kwargs={"literal_binds": True})) == (
+ "tables.perm IN ('[my_db].[table1](id:1)') OR "
+ "tables.schema_perm IN ('[my_db].[db1].[schema1]', '[my_other_db].[schema]') OR "
+ "tables.catalog_perm IN ('[my_db].[db1]')"
+ )
diff --git a/tests/unit_tests/databases/api_test.py b/tests/unit_tests/databases/api_test.py
index 8b309d573..f3f556e2e 100644
--- a/tests/unit_tests/databases/api_test.py
+++ b/tests/unit_tests/databases/api_test.py
@@ -2066,3 +2066,101 @@ def test_table_extra_metadata_unauthorized(
}
]
}
+
+
+def test_catalogs(
+ mocker: MockFixture,
+ client: Any,
+ full_api_access: None,
+) -> None:
+ """
+ Test the `catalogs` endpoint.
+ """
+ database = mocker.MagicMock()
+ database.get_all_catalog_names.return_value = {"db1", "db2"}
+ DatabaseDAO = mocker.patch("superset.databases.api.DatabaseDAO")
+ DatabaseDAO.find_by_id.return_value = database
+
+ security_manager = mocker.patch(
+ "superset.databases.api.security_manager",
+ new=mocker.MagicMock(),
+ )
+ security_manager.get_catalogs_accessible_by_user.return_value = {"db2"}
+
+ response = client.get("/api/v1/database/1/catalogs/")
+ assert response.status_code == 200
+ assert response.json == {"result": ["db2"]}
+ database.get_all_catalog_names.assert_called_with(
+ cache=database.catalog_cache_enabled,
+ cache_timeout=database.catalog_cache_timeout,
+ force=False,
+ )
+ security_manager.get_catalogs_accessible_by_user.assert_called_with(
+ database,
+ {"db1", "db2"},
+ )
+
+ response = client.get("/api/v1/database/1/catalogs/?q=(force:!t)")
+ database.get_all_catalog_names.assert_called_with(
+ cache=database.catalog_cache_enabled,
+ cache_timeout=database.catalog_cache_timeout,
+ force=True,
+ )
+
+
+def test_schemas(
+ mocker: MockFixture,
+ client: Any,
+ full_api_access: None,
+) -> None:
+ """
+ Test the `schemas` endpoint.
+ """
+ from superset.databases.api import DatabaseRestApi
+
+ database = mocker.MagicMock()
+ database.get_all_schema_names.return_value = {"schema1", "schema2"}
+ datamodel = mocker.patch.object(DatabaseRestApi, "datamodel")
+ datamodel.get.return_value = database
+
+ security_manager = mocker.patch(
+ "superset.databases.api.security_manager",
+ new=mocker.MagicMock(),
+ )
+ security_manager.get_schemas_accessible_by_user.return_value = {"schema2"}
+
+ response = client.get("/api/v1/database/1/schemas/")
+ assert response.status_code == 200
+ assert response.json == {"result": ["schema2"]}
+ database.get_all_schema_names.assert_called_with(
+ catalog=None,
+ cache=database.schema_cache_enabled,
+ cache_timeout=database.schema_cache_timeout,
+ force=False,
+ )
+ security_manager.get_schemas_accessible_by_user.assert_called_with(
+ database,
+ None,
+ {"schema1", "schema2"},
+ )
+
+ response = client.get("/api/v1/database/1/schemas/?q=(force:!t)")
+ database.get_all_schema_names.assert_called_with(
+ catalog=None,
+ cache=database.schema_cache_enabled,
+ cache_timeout=database.schema_cache_timeout,
+ force=True,
+ )
+
+ response = client.get("/api/v1/database/1/schemas/?q=(force:!t,catalog:catalog2)")
+ database.get_all_schema_names.assert_called_with(
+ catalog="catalog2",
+ cache=database.schema_cache_enabled,
+ cache_timeout=database.schema_cache_timeout,
+ force=True,
+ )
+ security_manager.get_schemas_accessible_by_user.assert_called_with(
+ database,
+ "catalog2",
+ {"schema1", "schema2"},
+ )
diff --git a/tests/unit_tests/databases/filters_test.py b/tests/unit_tests/databases/filters_test.py
new file mode 100644
index 000000000..a1e51fce2
--- /dev/null
+++ b/tests/unit_tests/databases/filters_test.py
@@ -0,0 +1,128 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from flask_appbuilder.models.sqla.interface import SQLAInterface
+from pytest_mock import MockerFixture
+from sqlalchemy import create_engine
+from sqlalchemy.orm import sessionmaker
+
+from superset.databases.filters import can_access_databases, DatabaseFilter
+from superset.extensions import security_manager
+
+
+def test_can_access_databases(mocker: MockerFixture) -> None:
+ """
+ Test the `can_access_databases` function.
+ """
+ mocker.patch.object(
+ security_manager,
+ "user_view_menu_names",
+ side_effect=[
+ {
+ "[my_db].[examples].[public].[table1](id:1)",
+ "[my_other_db].[examples].[public].[table1](id:2)",
+ },
+ {"[my_db].(id:42)", "[my_other_db].(id:43)"},
+ {"[my_db].[examples]", "[my_db].[other]"},
+ {
+ "[my_db].[examples].[information_schema]",
+ "[my_db].[other].[secret]",
+ "[third_db].[schema]",
+ },
+ ],
+ )
+
+ assert can_access_databases("datasource_access") == {"my_db", "my_other_db"}
+ assert can_access_databases("database_access") == {"my_db", "my_other_db"}
+ assert can_access_databases("catalog_access") == {"my_db"}
+ assert can_access_databases("schema_access") == {"my_db", "third_db"}
+
+
+def test_database_filter_full_db_access(mocker: MockerFixture) -> None:
+ """
+ Test the `DatabaseFilter` class when the user has full database access.
+
+ In this case the query should be returned unmodified.
+ """
+ from superset.models.core import Database
+
+ current_app = mocker.patch("superset.databases.filters.current_app")
+ current_app.config = {"EXTRA_DYNAMIC_QUERY_FILTERS": False}
+ mocker.patch.object(security_manager, "can_access_all_databases", return_value=True)
+
+ engine = create_engine("sqlite://")
+ Session = sessionmaker(bind=engine)
+ session = Session()
+ query = session.query(Database)
+
+ filter_ = DatabaseFilter("id", SQLAInterface(Database))
+ filtered_query = filter_.apply(query, None)
+
+ assert filtered_query == query
+
+
+def test_database_filter(mocker: MockerFixture) -> None:
+ """
+ Test the `DatabaseFilter` class with specific permissions.
+ """
+ from superset.models.core import Database
+
+ current_app = mocker.patch("superset.databases.filters.current_app")
+ current_app.config = {"EXTRA_DYNAMIC_QUERY_FILTERS": False}
+ mocker.patch.object(
+ security_manager,
+ "can_access_all_databases",
+ return_value=False,
+ )
+ mocker.patch.object(
+ security_manager,
+ "user_view_menu_names",
+ side_effect=[
+ # return lists instead of sets to ensure order
+ ["[my_db].(id:42)", "[my_other_db].(id:43)"],
+ ["[my_db].[examples]", "[my_db].[other]"],
+ [
+ "[my_db].[examples].[information_schema]",
+ "[my_db].[other].[secret]",
+ "[third_db].[schema]",
+ ],
+ [
+ "[my_db].[examples].[public].[table1](id:1)",
+ "[my_other_db].[examples].[public].[table1](id:2)",
+ ],
+ ],
+ )
+
+ engine = create_engine("sqlite://")
+ Session = sessionmaker(bind=engine)
+ session = Session()
+ query = session.query(Database)
+
+ filter_ = DatabaseFilter("id", SQLAInterface(Database))
+ filtered_query = filter_.apply(query, None)
+
+ compiled_query = filtered_query.statement.compile(
+ engine,
+ compile_kwargs={"literal_binds": True},
+ )
+ space = " " # pre-commit removes trailing spaces...
+ assert (
+ str(compiled_query)
+ == f"""SELECT dbs.uuid, dbs.created_on, dbs.changed_on, dbs.id, dbs.verbose_name, dbs.database_name, dbs.sqlalchemy_uri, dbs.password, dbs.cache_timeout, dbs.select_as_create_table_as, dbs.expose_in_sqllab, dbs.configuration_method, dbs.allow_run_async, dbs.allow_file_upload, dbs.allow_ctas, dbs.allow_cvas, dbs.allow_dml, dbs.force_ctas_schema, dbs.extra, dbs.encrypted_extra, dbs.impersonate_user, dbs.server_cert, dbs.is_managed_externally, dbs.external_url, dbs.created_by_fk, dbs.changed_by_fk{space}
+FROM dbs{space}
+WHERE '[' || dbs.database_name || '].(id:' || CAST(dbs.id AS VARCHAR) || ')' IN ('[my_db].(id:42)', '[my_other_db].(id:43)') OR dbs.database_name IN ('my_db', 'my_other_db', 'third_db')"""
+ )
diff --git a/tests/unit_tests/db_engine_specs/test_base.py b/tests/unit_tests/db_engine_specs/test_base.py
index 3bc05ee20..0950dcb43 100644
--- a/tests/unit_tests/db_engine_specs/test_base.py
+++ b/tests/unit_tests/db_engine_specs/test_base.py
@@ -301,3 +301,13 @@ def test_extra_table_metadata(mocker: MockFixture) -> None:
)
warnings.warn.assert_called()
+
+
+def test_get_default_catalog(mocker: MockFixture) -> None:
+ """
+ Test the `get_default_catalog` method.
+ """
+ from superset.db_engine_specs.base import BaseEngineSpec
+
+ database = mocker.MagicMock()
+ assert BaseEngineSpec.get_default_catalog(database) is None
diff --git a/tests/unit_tests/db_engine_specs/test_postgres.py b/tests/unit_tests/db_engine_specs/test_postgres.py
index bec77b3b7..d955d3ce5 100644
--- a/tests/unit_tests/db_engine_specs/test_postgres.py
+++ b/tests/unit_tests/db_engine_specs/test_postgres.py
@@ -175,3 +175,33 @@ SELECT * FROM some_table;
str(excinfo.value)
== "Users are not allowed to set a search path for security reasons."
)
+
+
+def test_adjust_engine_params() -> None:
+ """
+ Test `adjust_engine_params`.
+
+ The method can be used to adjust the catalog (database) dynamically.
+ """
+ from superset.db_engine_specs.postgres import PostgresEngineSpec
+
+ adjusted = PostgresEngineSpec.adjust_engine_params(
+ make_url("postgresql://user:password@host:5432/dev"),
+ {},
+ catalog="prod",
+ )
+ assert adjusted == (make_url("postgresql://user:password@host:5432/prod"), {})
+
+
+def test_get_default_catalog() -> None:
+ """
+ Test `get_default_catalog`.
+ """
+ from superset.db_engine_specs.postgres import PostgresEngineSpec
+ from superset.models.core import Database
+
+ database = Database(
+ database_name="postgres",
+ sqlalchemy_uri="postgresql://user:password@host:5432/dev",
+ )
+ assert PostgresEngineSpec.get_default_catalog(database) == "dev"
diff --git a/tests/unit_tests/explore/utils_test.py b/tests/unit_tests/explore/utils_test.py
index 9638392a5..2b274106c 100644
--- a/tests/unit_tests/explore/utils_test.py
+++ b/tests/unit_tests/explore/utils_test.py
@@ -272,6 +272,7 @@ def test_query_no_access(mocker: MockFixture, client) -> None:
from superset.models.sql_lab import Query
database = mocker.MagicMock()
+ database.get_default_catalog.return_value = None
database.get_default_schema_for_query.return_value = "public"
mocker.patch(
query_find_by_id,
diff --git a/tests/unit_tests/models/core_test.py b/tests/unit_tests/models/core_test.py
index ce3ad1822..5ee521fc0 100644
--- a/tests/unit_tests/models/core_test.py
+++ b/tests/unit_tests/models/core_test.py
@@ -229,3 +229,63 @@ def test_get_prequeries(mocker: MockFixture) -> None:
conn.cursor().execute.assert_has_calls(
[mocker.call("set a=1"), mocker.call("set b=2")]
)
+
+
+def test_catalog_cache() -> None:
+ """
+ Test the catalog cache.
+ """
+ database = Database(
+ database_name="db",
+ sqlalchemy_uri="sqlite://",
+ extra=json.dumps({"metadata_cache_timeout": {"catalog_cache_timeout": 10}}),
+ )
+
+ assert database.catalog_cache_enabled
+ assert database.catalog_cache_timeout == 10
+
+
+def test_get_default_catalog() -> None:
+ """
+ Test the `get_default_catalog` method.
+ """
+ database = Database(
+ database_name="db",
+ sqlalchemy_uri="postgresql://user:password@host:5432/examples",
+ )
+
+ assert database.get_default_catalog() == "examples"
+
+
+def test_get_default_schema(mocker: MockFixture) -> None:
+ """
+ Test the `get_default_schema` method.
+ """
+ database = Database(
+ database_name="db",
+ sqlalchemy_uri="postgresql://user:password@host:5432/examples",
+ )
+
+ get_inspector = mocker.patch.object(database, "get_inspector")
+ with get_inspector() as inspector:
+ inspector.default_schema_name = "public"
+
+ assert database.get_default_schema("examples") == "public"
+ get_inspector.assert_called_with(catalog="examples")
+
+
+def test_get_all_catalog_names(mocker: MockFixture) -> None:
+ """
+ Test the `get_all_catalog_names` method.
+ """
+ database = Database(
+ database_name="db",
+ sqlalchemy_uri="postgresql://user:password@host:5432/examples",
+ )
+
+ get_inspector = mocker.patch.object(database, "get_inspector")
+ with get_inspector() as inspector:
+ inspector.bind.execute.return_value = [("examples",), ("other",)]
+
+ assert database.get_all_catalog_names(force=True) == {"examples", "other"}
+ get_inspector.assert_called_with(ssh_tunnel=None)
diff --git a/tests/unit_tests/security/manager_test.py b/tests/unit_tests/security/manager_test.py
index 7ed32b0ab..033446290 100644
--- a/tests/unit_tests/security/manager_test.py
+++ b/tests/unit_tests/security/manager_test.py
@@ -26,7 +26,10 @@ from superset.connectors.sqla.models import Database, SqlaTable
from superset.exceptions import SupersetSecurityException
from superset.extensions import appbuilder
from superset.models.slice import Slice
-from superset.security.manager import query_context_modified, SupersetSecurityManager
+from superset.security.manager import (
+ query_context_modified,
+ SupersetSecurityManager,
+)
from superset.sql_parse import Table
from superset.superset_typing import AdhocMetric
from superset.utils.core import override_user
@@ -208,6 +211,7 @@ def test_raise_for_access_query_default_schema(
SqlaTable.query_datasources_by_name.return_value = []
database = mocker.MagicMock()
+ database.get_default_catalog.return_value = None
database.get_default_schema_for_query.return_value = "public"
query = mocker.MagicMock()
query.database = database
@@ -262,6 +266,7 @@ def test_raise_for_access_jinja_sql(mocker: MockFixture, app_context: None) -> N
SqlaTable.query_datasources_by_name.return_value = []
database = mocker.MagicMock()
+ database.get_default_catalog.return_value = None
database.get_default_schema_for_query.return_value = "public"
query = mocker.MagicMock()
query.database = database
@@ -531,3 +536,28 @@ def test_query_context_modified_mixed_chart(mocker: MockFixture) -> None:
}
query_context.queries = [QueryObject(metrics=requested_metrics)] # type: ignore
assert not query_context_modified(query_context)
+
+
+def test_get_catalog_perm() -> None:
+ """
+ Test the `get_catalog_perm` method.
+ """
+ sm = SupersetSecurityManager(appbuilder)
+
+ assert sm.get_catalog_perm("my_db", None) is None
+ assert sm.get_catalog_perm("my_db", "my_catalog") == "[my_db].[my_catalog]"
+
+
+def test_get_schema_perm() -> None:
+ """
+ Test the `get_schema_perm` method.
+ """
+ sm = SupersetSecurityManager(appbuilder)
+
+ assert sm.get_schema_perm("my_db", None, "my_schema") == "[my_db].[my_schema]"
+ assert (
+ sm.get_schema_perm("my_db", "my_catalog", "my_schema")
+ == "[my_db].[my_catalog].[my_schema]"
+ )
+ assert sm.get_schema_perm("my_db", None, None) is None
+ assert sm.get_schema_perm("my_db", "my_catalog", None) is None
diff --git a/tests/unit_tests/utils/filters_test.py b/tests/unit_tests/utils/filters_test.py
new file mode 100644
index 000000000..b41774bb2
--- /dev/null
+++ b/tests/unit_tests/utils/filters_test.py
@@ -0,0 +1,54 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from pytest_mock import MockerFixture
+from sqlalchemy import create_engine
+
+from superset.utils.filters import get_dataset_access_filters
+
+
+def test_get_dataset_access_filters(mocker: MockerFixture) -> None:
+ """
+ Test the `get_dataset_access_filters` function.
+ """
+ from superset.connectors.sqla.models import SqlaTable
+ from superset.extensions import security_manager
+
+ mocker.patch.object(
+ security_manager,
+ "get_accessible_databases",
+ return_value=[1, 3],
+ )
+ mocker.patch.object(
+ security_manager,
+ "user_view_menu_names",
+ side_effect=[
+ {"[db].[catalog1].[schema1].[table1](id:1)"},
+ {"[db].[catalog1].[schema2]"},
+ {"[db].[catalog2]"},
+ ],
+ )
+
+ clause = get_dataset_access_filters(SqlaTable)
+ engine = create_engine("sqlite://")
+ compiled_query = clause.compile(engine, compile_kwargs={"literal_binds": True})
+ assert str(compiled_query) == (
+ "dbs.id IN (1, 3) "
+ "OR tables.perm IN ('[db].[catalog1].[schema1].[table1](id:1)') "
+ "OR tables.catalog_perm IN ('[db].[catalog2]') OR "
+ "tables.schema_perm IN ('[db].[catalog1].[schema2]')"
+ )
diff --git a/tests/unit_tests/utils/test_core.py b/tests/unit_tests/utils/test_core.py
index a8081e0f3..2ebec87c2 100644
--- a/tests/unit_tests/utils/test_core.py
+++ b/tests/unit_tests/utils/test_core.py
@@ -29,6 +29,7 @@ from superset.utils.core import (
DateColumn,
generic_find_constraint_name,
generic_find_fk_constraint_name,
+ get_datasource_full_name,
is_test,
normalize_dttm_col,
parse_boolean_string,
@@ -369,3 +370,29 @@ def test_generic_find_fk_constraint_none_exist():
)
assert result is None
+
+
+def test_get_datasource_full_name():
+ """
+ Test the `get_datasource_full_name` function.
+
+ This is used to build permissions, so it doesn't really return the datasource full
+ name. Instead, it returns a fully qualified table name that includes the database
+ name and schema, with each part wrapped in square brackets.
+ """
+ assert (
+ get_datasource_full_name("db", "table", "catalog", "schema")
+ == "[db].[catalog].[schema].[table]"
+ )
+
+ assert get_datasource_full_name("db", "table", None, None) == "[db].[table]"
+
+ assert (
+ get_datasource_full_name("db", "table", None, "schema")
+ == "[db].[schema].[table]"
+ )
+
+ assert (
+ get_datasource_full_name("db", "table", "catalog", None)
+ == "[db].[catalog].[table]"
+ )
diff --git a/tests/unit_tests/views/database/__init__.py b/tests/unit_tests/views/database/__init__.py
new file mode 100644
index 000000000..13a83393a
--- /dev/null
+++ b/tests/unit_tests/views/database/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/unit_tests/views/database/mixins_test.py b/tests/unit_tests/views/database/mixins_test.py
new file mode 100644
index 000000000..1752d9763
--- /dev/null
+++ b/tests/unit_tests/views/database/mixins_test.py
@@ -0,0 +1,65 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from pytest_mock import MockerFixture
+
+from superset.views.database.mixins import DatabaseMixin
+
+
+def test_pre_add_update_with_catalog(mocker: MockerFixture) -> None:
+    """
+    Test `_pre_add_update` on a DB with catalog support: it must add db/catalog/schema perms.
+    """
+    from superset.models.core import Database
+
+    add_permission_view_menu = mocker.patch(  # spy on permission registration
+        "superset.views.database.mixins.security_manager.add_permission_view_menu"
+    )
+
+    database = Database(  # the id feeds the `(id:42)` suffix of the database perm
+        database_name="my_db",
+        id=42,
+        sqlalchemy_uri="postgresql://user:password@host:5432/examples",
+    )
+    mocker.patch.object(  # two catalogs -> expect two catalog_access perms
+        database,
+        "get_all_catalog_names",
+        return_value=["examples", "other"],
+    )
+    mocker.patch.object(  # one schema list returned per catalog, in catalog order
+        database,
+        "get_all_schema_names",
+        side_effect=[
+            ["public", "information_schema"],
+            ["secret"],
+        ],
+    )
+
+    mixin = DatabaseMixin()
+    mixin._pre_add_update(database)
+
+    add_permission_view_menu.assert_has_calls(  # call order is not part of the contract
+        [
+            mocker.call("database_access", "[my_db].(id:42)"),
+            mocker.call("catalog_access", "[my_db].[examples]"),
+            mocker.call("catalog_access", "[my_db].[other]"),
+            mocker.call("schema_access", "[my_db].[examples].[public]"),
+            mocker.call("schema_access", "[my_db].[examples].[information_schema]"),
+            mocker.call("schema_access", "[my_db].[other].[secret]"),
+        ],
+        any_order=True,
+    )