From e90246fd1fc27d418c37b864ab4cc63a639d4a97 Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Mon, 6 May 2024 11:41:58 -0400 Subject: [PATCH] feat(SIP-95): permissions for catalogs (#28317) --- superset/commands/database/create.py | 35 +- superset/commands/database/tables.py | 31 +- superset/commands/database/update.py | 204 +++++++-- superset/commands/sql_lab/export.py | 6 +- superset/config.py | 6 +- superset/connectors/sqla/models.py | 63 ++- superset/constants.py | 1 + superset/databases/api.py | 82 +++- superset/databases/filters.py | 23 +- superset/databases/schemas.py | 15 + superset/db_engine_specs/base.py | 24 +- superset/db_engine_specs/bigquery.py | 6 +- superset/db_engine_specs/clickhouse.py | 4 +- superset/db_engine_specs/impala.py | 7 +- superset/db_engine_specs/postgres.py | 34 +- superset/db_engine_specs/presto.py | 6 +- superset/db_engine_specs/snowflake.py | 8 +- superset/extensions/metadb.py | 5 - superset/migrations/shared/catalogs.py | 116 +++++ ...58d051681a3b_add_catalog_perm_to_tables.py | 53 +++ superset/models/core.py | 82 ++-- superset/models/helpers.py | 16 +- superset/models/slice.py | 1 + superset/security/manager.py | 398 ++++++++++++++---- superset/utils/cache.py | 10 +- superset/utils/core.py | 15 +- superset/utils/filters.py | 2 + superset/views/database/mixins.py | 28 +- .../integration_tests/databases/api_tests.py | 115 +++-- .../databases/commands_tests.py | 8 +- .../db_engine_specs/postgres_tests.py | 4 +- tests/integration_tests/model_tests.py | 8 +- tests/integration_tests/security_tests.py | 24 +- tests/integration_tests/sqllab_tests.py | 8 +- .../commands/databases/create_test.py | 128 ++++++ .../commands/databases/tables_test.py | 203 +++++++++ .../commands/databases/update_test.py | 272 ++++++++++++ tests/unit_tests/conftest.py | 2 + .../unit_tests/connectors/sqla/models_test.py | 123 ++++++ tests/unit_tests/databases/api_test.py | 98 +++++ tests/unit_tests/databases/filters_test.py | 128 ++++++ 
tests/unit_tests/db_engine_specs/test_base.py | 10 + .../db_engine_specs/test_postgres.py | 30 ++ tests/unit_tests/explore/utils_test.py | 1 + tests/unit_tests/models/core_test.py | 60 +++ tests/unit_tests/security/manager_test.py | 32 +- tests/unit_tests/utils/filters_test.py | 54 +++ tests/unit_tests/utils/test_core.py | 27 ++ tests/unit_tests/views/database/__init__.py | 16 + .../unit_tests/views/database/mixins_test.py | 65 +++ 50 files changed, 2381 insertions(+), 316 deletions(-) create mode 100644 superset/migrations/shared/catalogs.py create mode 100644 superset/migrations/versions/2024-05-01_10-52_58d051681a3b_add_catalog_perm_to_tables.py create mode 100644 tests/unit_tests/commands/databases/create_test.py create mode 100644 tests/unit_tests/commands/databases/tables_test.py create mode 100644 tests/unit_tests/commands/databases/update_test.py create mode 100644 tests/unit_tests/databases/filters_test.py create mode 100644 tests/unit_tests/utils/filters_test.py create mode 100644 tests/unit_tests/views/database/__init__.py create mode 100644 tests/unit_tests/views/database/mixins_test.py diff --git a/superset/commands/database/create.py b/superset/commands/database/create.py index 4903938eb..b45107ca8 100644 --- a/superset/commands/database/create.py +++ b/superset/commands/database/create.py @@ -97,12 +97,37 @@ class CreateDatabaseCommand(BaseCommand): db.session.commit() - # adding a new database we always want to force refresh schema list - schemas = database.get_all_schema_names(cache=False, ssh_tunnel=ssh_tunnel) - for schema in schemas: - security_manager.add_permission_view_menu( - "schema_access", security_manager.get_schema_perm(database, schema) + # add catalog/schema permissions + if database.db_engine_spec.supports_catalog: + catalogs = database.get_all_catalog_names( + cache=False, + ssh_tunnel=ssh_tunnel, ) + for catalog in catalogs: + security_manager.add_permission_view_menu( + "catalog_access", + security_manager.get_catalog_perm( + 
database.database_name, catalog + ), + ) + else: + # add a dummy catalog for DBs that don't support them + catalogs = [None] + + for catalog in catalogs: + for schema in database.get_all_schema_names( + catalog=catalog, + cache=False, + ssh_tunnel=ssh_tunnel, + ): + security_manager.add_permission_view_menu( + "schema_access", + security_manager.get_schema_perm( + database.database_name, + catalog, + schema, + ), + ) except ( SSHTunnelInvalidError, diff --git a/superset/commands/database/tables.py b/superset/commands/database/tables.py index 055c0be9a..b16fcfc50 100644 --- a/superset/commands/database/tables.py +++ b/superset/commands/database/tables.py @@ -14,6 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + +from __future__ import annotations + import logging from typing import Any, cast @@ -29,7 +32,6 @@ from superset.daos.database import DatabaseDAO from superset.exceptions import SupersetException from superset.extensions import db, security_manager from superset.models.core import Database -from superset.utils.core import DatasourceName logger = logging.getLogger(__name__) @@ -37,8 +39,15 @@ logger = logging.getLogger(__name__) class TablesDatabaseCommand(BaseCommand): _model: Database - def __init__(self, db_id: int, schema_name: str, force: bool): + def __init__( + self, + db_id: int, + catalog_name: str | None, + schema_name: str, + force: bool, + ): self._db_id = db_id + self._catalog_name = catalog_name self._schema_name = schema_name self._force = force @@ -47,11 +56,11 @@ class TablesDatabaseCommand(BaseCommand): try: tables = security_manager.get_datasources_accessible_by_user( database=self._model, + catalog=self._catalog_name, schema=self._schema_name, datasource_names=sorted( - DatasourceName(*datasource_name) - for datasource_name in self._model.get_all_table_names_in_schema( - catalog=None, + self._model.get_all_table_names_in_schema( + 
catalog=self._catalog_name, schema=self._schema_name, force=self._force, cache=self._model.table_cache_enabled, @@ -62,11 +71,11 @@ class TablesDatabaseCommand(BaseCommand): views = security_manager.get_datasources_accessible_by_user( database=self._model, + catalog=self._catalog_name, schema=self._schema_name, datasource_names=sorted( - DatasourceName(*datasource_name) - for datasource_name in self._model.get_all_view_names_in_schema( - catalog=None, + self._model.get_all_view_names_in_schema( + catalog=self._catalog_name, schema=self._schema_name, force=self._force, cache=self._model.table_cache_enabled, @@ -81,11 +90,15 @@ class TablesDatabaseCommand(BaseCommand): db.session.query(SqlaTable) .filter( SqlaTable.database_id == self._model.id, + SqlaTable.catalog == self._catalog_name, SqlaTable.schema == self._schema_name, ) .options( load_only( - SqlaTable.schema, SqlaTable.table_name, SqlaTable.extra + SqlaTable.catalog, + SqlaTable.schema, + SqlaTable.table_name, + SqlaTable.extra, ), lazyload(SqlaTable.columns), lazyload(SqlaTable.metrics), diff --git a/superset/commands/database/update.py b/superset/commands/database/update.py index 5e0968954..c59984238 100644 --- a/superset/commands/database/update.py +++ b/superset/commands/database/update.py @@ -18,10 +18,9 @@ from __future__ import annotations import logging -from typing import Any, Optional +from typing import Any from flask_appbuilder.models.sqla import Model -from marshmallow import ValidationError from superset import is_feature_enabled, security_manager from superset.commands.base import BaseCommand @@ -50,12 +49,12 @@ logger = logging.getLogger(__name__) class UpdateDatabaseCommand(BaseCommand): - _model: Optional[Database] + _model: Database | None def __init__(self, model_id: int, data: dict[str, Any]): self._properties = data.copy() self._model_id = model_id - self._model: Optional[Database] = None + self._model: Database | None = None def run(self) -> Model: self._model = 
DatabaseDAO.find_by_id(self._model_id) @@ -85,7 +84,7 @@ class UpdateDatabaseCommand(BaseCommand): ) database.set_sqlalchemy_uri(database.sqlalchemy_uri) ssh_tunnel = self._handle_ssh_tunnel(database) - self._refresh_schemas(database, original_database_name, ssh_tunnel) + self._refresh_catalogs(database, original_database_name, ssh_tunnel) except SSHTunnelError as ex: # allow exception to bubble for debugbing information raise ex @@ -121,67 +120,200 @@ class UpdateDatabaseCommand(BaseCommand): ssh_tunnel_properties, ).run() - def _refresh_schemas( + def _get_catalog_names( self, database: Database, - original_database_name: str, - ssh_tunnel: Optional[SSHTunnel], - ) -> None: + ssh_tunnel: SSHTunnel | None, + ) -> set[str]: """ - Add permissions for any new schemas. + Helper method to load catalogs. + + This method captures a generic exception, since errors could potentially come + from any of the 50+ database drivers we support. """ try: - schemas = database.get_all_schema_names(ssh_tunnel=ssh_tunnel) + return database.get_all_catalog_names( + force=True, + ssh_tunnel=ssh_tunnel, + ) except Exception as ex: db.session.rollback() raise DatabaseConnectionFailedError() from ex - for schema in schemas: - original_vm = security_manager.get_schema_perm( + def _get_schema_names( + self, + database: Database, + catalog: str | None, + ssh_tunnel: SSHTunnel | None, + ) -> set[str]: + """ + Helper method to load schemas. + + This method captures a generic exception, since errors could potentially come + from any of the 50+ database drivers we support. + """ + try: + return database.get_all_schema_names( + force=True, + catalog=catalog, + ssh_tunnel=ssh_tunnel, + ) + except Exception as ex: + db.session.rollback() + raise DatabaseConnectionFailedError() from ex + + def _refresh_catalogs( + self, + database: Database, + original_database_name: str, + ssh_tunnel: SSHTunnel | None, + ) -> None: + """ + Add permissions for any new catalogs and schemas. 
+ """ + catalogs = ( + self._get_catalog_names(database, ssh_tunnel) + if database.db_engine_spec.supports_catalog + else [None] + ) + + for catalog in catalogs: + schemas = self._get_schema_names(database, catalog, ssh_tunnel) + + if catalog: + perm = security_manager.get_catalog_perm( + original_database_name, + catalog, + ) + existing_pvm = security_manager.find_permission_view_menu( + "catalog_access", + perm, + ) + if not existing_pvm: + # new catalog + security_manager.add_permission_view_menu( + "catalog_access", + security_manager.get_catalog_perm( + database.database_name, + catalog, + ), + ) + for schema in schemas: + security_manager.add_permission_view_menu( + "schema_access", + security_manager.get_schema_perm( + database.database_name, + catalog, + schema, + ), + ) + continue + + # add possible new schemas in catalog + self._refresh_schemas( + database, original_database_name, + catalog, + schemas, + ) + + if original_database_name != database.database_name: + self._rename_database_in_permissions( + database, + original_database_name, + catalog, + schemas, + ) + + db.session.commit() + + def _refresh_schemas( + self, + database: Database, + original_database_name: str, + catalog: str | None, + schemas: set[str], + ) -> None: + """ + Add new schemas that don't have permissions yet. 
+ """ + for schema in schemas: + perm = security_manager.get_schema_perm( + original_database_name, + catalog, schema, ) existing_pvm = security_manager.find_permission_view_menu( "schema_access", - original_vm, + perm, ) if not existing_pvm: - # new schema - security_manager.add_permission_view_menu( - "schema_access", - security_manager.get_schema_perm(database.database_name, schema), + new_name = security_manager.get_schema_perm( + database.database_name, + catalog, + schema, ) - continue + security_manager.add_permission_view_menu("schema_access", new_name) - if original_database_name == database.database_name: - continue + def _rename_database_in_permissions( + self, + database: Database, + original_database_name: str, + catalog: str | None, + schemas: set[str], + ) -> None: + new_name = security_manager.get_catalog_perm( + database.database_name, + catalog, + ) - # rename existing schema permission - existing_pvm.view_menu.name = security_manager.get_schema_perm( + # rename existing catalog permission + if catalog: + perm = security_manager.get_catalog_perm( + original_database_name, + catalog, + ) + existing_pvm = security_manager.find_permission_view_menu( + "catalog_access", + perm, + ) + if existing_pvm: + existing_pvm.view_menu.name = new_name + + for schema in schemas: + new_name = security_manager.get_schema_perm( database.database_name, + catalog, schema, ) + # rename existing schema permission + perm = security_manager.get_schema_perm( + original_database_name, + catalog, + schema, + ) + existing_pvm = security_manager.find_permission_view_menu( + "schema_access", + perm, + ) + if existing_pvm: + existing_pvm.view_menu.name = new_name + # rename permissions on datasets and charts for dataset in DatabaseDAO.get_datasets( database.id, - catalog=None, + catalog=catalog, schema=schema, ): - dataset.schema_perm = existing_pvm.view_menu.name + dataset.schema_perm = new_name for chart in DatasetDAO.get_related_objects(dataset.id)["charts"]: - 
chart.schema_perm = existing_pvm.view_menu.name - - db.session.commit() + chart.schema_perm = new_name def validate(self) -> None: - exceptions: list[ValidationError] = [] - database_name: Optional[str] = self._properties.get("database_name") - if database_name: - # Check database_name uniqueness + if database_name := self._properties.get("database_name"): if not DatabaseDAO.validate_update_uniqueness( - self._model_id, database_name + self._model_id, + database_name, ): - exceptions.append(DatabaseExistsValidationError()) - if exceptions: - raise DatabaseInvalidError(exceptions=exceptions) + raise DatabaseInvalidError(exceptions=[DatabaseExistsValidationError()]) diff --git a/superset/commands/sql_lab/export.py b/superset/commands/sql_lab/export.py index aa6050f27..bfa739054 100644 --- a/superset/commands/sql_lab/export.py +++ b/superset/commands/sql_lab/export.py @@ -126,7 +126,11 @@ class SqlResultExportCommand(BaseCommand): }: # remove extra row from `increased_limit` limit -= 1 - df = self._query.database.get_df(sql, self._query.schema)[:limit] + df = self._query.database.get_df( + sql, + self._query.catalog, + self._query.schema, + )[:limit] csv_data = csv.df_to_escaped_csv(df, index=False, **config["CSV_EXPORT"]) diff --git a/superset/config.py b/superset/config.py index 9388edbe8..7851938b7 100644 --- a/superset/config.py +++ b/superset/config.py @@ -564,9 +564,9 @@ IS_FEATURE_ENABLED_FUNC: Callable[[str, bool | None], bool] | None = None # # Takes as a parameter the common bootstrap payload before transformations. # Returns a dict containing data that should be added or overridden to the payload. 
-COMMON_BOOTSTRAP_OVERRIDES_FUNC: Callable[[dict[str, Any]], dict[str, Any]] = ( # noqa: E731 - lambda data: {} -) # default: empty dict +COMMON_BOOTSTRAP_OVERRIDES_FUNC: Callable[ # noqa: E731 + [dict[str, Any]], dict[str, Any] +] = lambda data: {} # EXTRA_CATEGORICAL_COLOR_SCHEMES is used for adding custom categorical color schemes # example code for "My custom warm to hot" color scheme diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 719d5af58..12fbdc3bd 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -211,6 +211,7 @@ class BaseDatasource(AuditMixinNullable, ImportExportMixin): # pylint: disable= params = Column(String(1000)) perm = Column(String(1000)) schema_perm = Column(String(1000)) + catalog_perm = Column(String(1000), nullable=True, default=None) is_managed_externally = Column(Boolean, nullable=False, default=False) external_url = Column(Text, nullable=True) @@ -1261,9 +1262,20 @@ class SqlaTable( anchor = f'{name}' return Markup(anchor) + def get_catalog_perm(self) -> str | None: + """Returns catalog permission if present, database one otherwise.""" + return security_manager.get_catalog_perm( + self.database.database_name, + self.catalog, + ) + def get_schema_perm(self) -> str | None: """Returns schema permission if present, database one otherwise.""" - return security_manager.get_schema_perm(self.database, self.schema or None) + return security_manager.get_schema_perm( + self.database.database_name, + self.catalog, + self.schema or None, + ) def get_perm(self) -> str: """ @@ -1282,7 +1294,10 @@ class SqlaTable( @property def full_name(self) -> str: return utils.get_datasource_full_name( - self.database, self.table_name, schema=self.schema + self.database, + self.table_name, + catalog=self.catalog, + schema=self.schema, ) @property @@ -1736,7 +1751,10 @@ class SqlaTable( try: df = self.database.get_df( - sql, self.schema or None, mutator=assign_column_label + sql, + 
None, + self.schema or None, + mutator=assign_column_label, ) except (SupersetErrorException, SupersetErrorsException) as ex: # SupersetError(s) exception should not be captured; instead, they should @@ -1870,34 +1888,45 @@ class SqlaTable( cls, database: Database, datasource_name: str, + catalog: str | None = None, schema: str | None = None, ) -> list[SqlaTable]: - query = ( - db.session.query(cls) - .filter_by(database_id=database.id) - .filter_by(table_name=datasource_name) - ) + filters = { + "database_id": database.id, + "table_name": datasource_name, + } + if catalog: + filters["catalog"] = catalog if schema: - query = query.filter_by(schema=schema) - return query.all() + filters["schema"] = schema + + return db.session.query(cls).filter_by(**filters).all() @classmethod def query_datasources_by_permissions( # pylint: disable=invalid-name cls, database: Database, permissions: set[str], + catalog_perms: set[str], schema_perms: set[str], ) -> list[SqlaTable]: - # TODO(hughhhh): add unit test + # remove empty sets from the query, since SQLAlchemy produces horrible SQL for + # Model.column._in({}): + # + # table.column IN (SELECT 1 FROM (SELECT 1) WHERE 1!=1) + filters = [ + method.in_(perms) + for method, perms in zip( + (SqlaTable.perm, SqlaTable.schema_perm, SqlaTable.catalog_perm), + (permissions, schema_perms, catalog_perms), + ) + if perms + ] + return ( db.session.query(cls) .filter_by(database_id=database.id) - .filter( - or_( - SqlaTable.perm.in_(permissions), - SqlaTable.schema_perm.in_(schema_perms), - ) - ) + .filter(or_(*filters)) .all() ) diff --git a/superset/constants.py b/superset/constants.py index bac83ae55..42cde6115 100644 --- a/superset/constants.py +++ b/superset/constants.py @@ -132,6 +132,7 @@ MODEL_API_RW_METHOD_PERMISSION_MAP = { "related_objects": "read", "tables": "read", "schemas": "read", + "catalogs": "read", "select_star": "read", "table_metadata": "read", "table_metadata_deprecated": "read", diff --git a/superset/databases/api.py 
b/superset/databases/api.py index 46a0b8cc6..116ac9f46 100644 --- a/superset/databases/api.py +++ b/superset/databases/api.py @@ -73,10 +73,12 @@ from superset.daos.database import DatabaseDAO, DatabaseUserOAuth2TokensDAO from superset.databases.decorators import check_table_access from superset.databases.filters import DatabaseFilter, DatabaseUploadEnabledFilter from superset.databases.schemas import ( + CatalogsResponseSchema, ColumnarMetadataUploadFilePostSchema, ColumnarUploadPostSchema, CSVMetadataUploadFilePostSchema, CSVUploadPostSchema, + database_catalogs_query_schema, database_schemas_query_schema, database_tables_query_schema, DatabaseConnectionSchema, @@ -146,6 +148,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi): "table_extra_metadata", "table_extra_metadata_deprecated", "select_star", + "catalogs", "schemas", "test_connection", "related_objects", @@ -266,6 +269,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi): edit_model_schema = DatabasePutSchema() apispec_parameter_schemas = { + "database_catalogs_query_schema": database_catalogs_query_schema, "database_schemas_query_schema": database_schemas_query_schema, "database_tables_query_schema": database_tables_query_schema, "get_export_ids_schema": get_export_ids_schema, @@ -273,6 +277,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi): openapi_spec_tag = "Database" openapi_spec_component_schemas = ( + CatalogsResponseSchema, ColumnarUploadPostSchema, CSVUploadPostSchema, DatabaseConnectionSchema, @@ -604,6 +609,70 @@ class DatabaseRestApi(BaseSupersetModelRestApi): ) return self.response_422(message=str(ex)) + @expose("//catalogs/") + @protect() + @safe + @rison(database_catalogs_query_schema) + @statsd_metrics + @event_logger.log_this_with_context( + action=lambda self, *args, **kwargs: f"{self.__class__.__name__}" f".catalogs", + log_to_statsd=False, + ) + def catalogs(self, pk: int, **kwargs: Any) -> FlaskResponse: + """Get all catalogs from a database. 
+ --- + get: + summary: Get all catalogs from a database + parameters: + - in: path + schema: + type: integer + name: pk + description: The database id + - in: query + name: q + content: + application/json: + schema: + $ref: '#/components/schemas/database_catalogs_query_schema' + responses: + 200: + description: A List of all catalogs from the database + content: + application/json: + schema: + $ref: "#/components/schemas/CatalogsResponseSchema" + 400: + $ref: '#/components/responses/400' + 401: + $ref: '#/components/responses/401' + 404: + $ref: '#/components/responses/404' + 500: + $ref: '#/components/responses/500' + """ + database = DatabaseDAO.find_by_id(pk) + if not database: + return self.response_404() + try: + catalogs = database.get_all_catalog_names( + cache=database.catalog_cache_enabled, + cache_timeout=database.catalog_cache_timeout or None, + force=kwargs["rison"].get("force", False), + ) + catalogs = security_manager.get_catalogs_accessible_by_user( + database, + catalogs, + ) + return self.response(200, result=list(catalogs)) + except OperationalError: + return self.response( + 500, + message="There was an error connecting to the database", + ) + except SupersetException as ex: + return self.response(ex.status, message=ex.message) + @expose("//schemas/") @protect() @safe @@ -650,13 +719,19 @@ class DatabaseRestApi(BaseSupersetModelRestApi): if not database: return self.response_404() try: + catalog = kwargs["rison"].get("catalog") schemas = database.get_all_schema_names( + catalog=catalog, cache=database.schema_cache_enabled, cache_timeout=database.schema_cache_timeout or None, force=kwargs["rison"].get("force", False), ) - schemas = security_manager.get_schemas_accessible_by_user(database, schemas) - return self.response(200, result=schemas) + schemas = security_manager.get_schemas_accessible_by_user( + database, + catalog, + schemas, + ) + return self.response(200, result=list(schemas)) except OperationalError: return self.response( 500, 
message="There was an error connecting to the database" @@ -718,10 +793,11 @@ class DatabaseRestApi(BaseSupersetModelRestApi): $ref: '#/components/responses/500' """ force = kwargs["rison"].get("force", False) + catalog_name = kwargs["rison"].get("catalog_name") schema_name = kwargs["rison"].get("schema_name", "") try: - command = TablesDatabaseCommand(pk, schema_name, force) + command = TablesDatabaseCommand(pk, catalog_name, schema_name, force) payload = command.run() return self.response(200, **payload) except DatabaseNotFoundError: diff --git a/superset/databases/filters.py b/superset/databases/filters.py index 33748da4b..420a55fd2 100644 --- a/superset/databases/filters.py +++ b/superset/databases/filters.py @@ -28,11 +28,12 @@ from superset.models.core import Database from superset.views.base import BaseFilter -def can_access_databases( - view_menu_name: str, -) -> set[str]: +def can_access_databases(view_menu_name: str) -> set[str]: + """ + Return names of databases available in `view_menu_name`. 
+ """ return { - security_manager.unpack_database_and_schema(vm).database + vm.split(".")[0][1:-1] for vm in security_manager.user_view_menu_names(view_menu_name) } @@ -56,17 +57,21 @@ class DatabaseFilter(BaseFilter): # pylint: disable=too-few-public-methods # We can proceed with default filtering now if security_manager.can_access_all_databases(): return query - database_perms = security_manager.user_view_menu_names("database_access") - schema_access_databases = can_access_databases("schema_access") + database_perms = security_manager.user_view_menu_names("database_access") + catalog_access_databases = can_access_databases("catalog_access") + schema_access_databases = can_access_databases("schema_access") datasource_access_databases = can_access_databases("datasource_access") + database_names = sorted( + catalog_access_databases + | schema_access_databases + | datasource_access_databases + ) return query.filter( or_( self.model.perm.in_(database_perms), - self.model.database_name.in_( - [*schema_access_databases, *datasource_access_databases] - ), + self.model.database_name.in_(database_names), ) ) diff --git a/superset/databases/schemas.py b/superset/databases/schemas.py index 074139857..4318a1b48 100644 --- a/superset/databases/schemas.py +++ b/superset/databases/schemas.py @@ -56,6 +56,14 @@ from superset.security.analytics_db_safety import check_sqlalchemy_uri from superset.utils.core import markdown, parse_ssl_cert database_schemas_query_schema = { + "type": "object", + "properties": { + "force": {"type": "boolean"}, + "catalog": {"type": "string"}, + }, +} + +database_catalogs_query_schema = { "type": "object", "properties": {"force": {"type": "boolean"}}, } @@ -65,6 +73,7 @@ database_tables_query_schema = { "properties": { "force": {"type": "boolean"}, "schema_name": {"type": "string"}, + "catalog_name": {"type": "string"}, }, "required": ["schema_name"], } @@ -712,6 +721,12 @@ class SchemasResponseSchema(Schema): ) +class CatalogsResponseSchema(Schema): + 
result = fields.List( + fields.String(metadata={"description": "A database catalog name"}) + ) + + class DatabaseTablesResponse(Schema): extra = fields.Dict( metadata={"description": "Extra data used to specify column metadata"} diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 3cc131512..4ee4e4dc0 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -404,6 +404,10 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods # Does the DB support catalogs? A catalog here is a group of schemas, and has # different names depending on the DB: BigQuery calles it a "project", Postgres calls # it a "database", Trino calls it a "catalog", etc. + # + # When this is changed to true in a DB engine spec it MUST support the + # `get_default_catalog` and `get_catalog_names` methods. In addition, you MUST write + # a database migration updating any existing schema permissions. supports_catalog = False # Can the catalog be changed on a per-query basis? @@ -638,10 +642,20 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods return driver in cls.drivers + @classmethod + def get_default_catalog( + cls, + database: Database, # pylint: disable=unused-argument + ) -> str | None: + """ + Return the default catalog for a given database. + """ + return None + @classmethod def get_default_schema(cls, database: Database, catalog: str | None) -> str | None: """ - Return the default schema in a given database. + Return the default schema for a catalog in a given database. """ with database.get_inspector(catalog=catalog) as inspector: return inspector.default_schema_name @@ -1412,24 +1426,24 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods cls, database: Database, inspector: Inspector, - ) -> list[str]: + ) -> set[str]: """ Get all catalogs from database. This needs to be implemented per database, since SQLAlchemy doesn't offer an abstraction. 
""" - return [] + return set() @classmethod - def get_schema_names(cls, inspector: Inspector) -> list[str]: + def get_schema_names(cls, inspector: Inspector) -> set[str]: """ Get all schemas from database :param inspector: SqlAlchemy inspector :return: All schemas in the database """ - return sorted(inspector.get_schema_names()) + return set(inspector.get_schema_names()) @classmethod def get_table_names( # pylint: disable=unused-argument diff --git a/superset/db_engine_specs/bigquery.py b/superset/db_engine_specs/bigquery.py index 8a2612f5b..ca52bd51c 100644 --- a/superset/db_engine_specs/bigquery.py +++ b/superset/db_engine_specs/bigquery.py @@ -127,7 +127,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met allows_hidden_cc_in_orderby = True - supports_catalog = True + supports_catalog = False """ https://www.python.org/dev/peps/pep-0249/#arraysize @@ -464,7 +464,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met cls, database: Database, inspector: Inspector, - ) -> list[str]: + ) -> set[str]: """ Get all catalogs. 
@@ -475,7 +475,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met client = cls._get_client(engine) projects = client.list_projects() - return sorted(project.project_id for project in projects) + return {project.project_id for project in projects} @classmethod def get_allow_cost_estimate(cls, extra: dict[str, Any]) -> bool: diff --git a/superset/db_engine_specs/clickhouse.py b/superset/db_engine_specs/clickhouse.py index 4346f77d6..aa5399929 100644 --- a/superset/db_engine_specs/clickhouse.py +++ b/superset/db_engine_specs/clickhouse.py @@ -278,7 +278,7 @@ class ClickHouseConnectEngineSpec(BasicParametersMixin, ClickHouseEngineSpec): @classmethod def get_function_names(cls, database: Database) -> list[str]: - # pylint: disable=import-outside-toplevel,import-error + # pylint: disable=import-outside-toplevel, import-error from clickhouse_connect.driver.exceptions import ClickHouseError if cls._function_names: @@ -340,7 +340,7 @@ class ClickHouseConnectEngineSpec(BasicParametersMixin, ClickHouseEngineSpec): def validate_parameters( cls, properties: BasicPropertiesType ) -> list[SupersetError]: - # pylint: disable=import-outside-toplevel,import-error + # pylint: disable=import-outside-toplevel, import-error from clickhouse_connect.driver import default_port parameters = properties.get("parameters", {}) diff --git a/superset/db_engine_specs/impala.py b/superset/db_engine_specs/impala.py index 1d3ec4e9e..d7d1862aa 100644 --- a/superset/db_engine_specs/impala.py +++ b/superset/db_engine_specs/impala.py @@ -74,13 +74,12 @@ class ImpalaEngineSpec(BaseEngineSpec): return None @classmethod - def get_schema_names(cls, inspector: Inspector) -> list[str]: - schemas = [ + def get_schema_names(cls, inspector: Inspector) -> set[str]: + return { row[0] for row in inspector.engine.execute("SHOW SCHEMAS") if not row[0].startswith("_") - ] - return schemas + } @classmethod def has_implicit_cancel(cls) -> bool: diff --git 
a/superset/db_engine_specs/postgres.py b/superset/db_engine_specs/postgres.py index ce87aa1f9..bba2157e0 100644 --- a/superset/db_engine_specs/postgres.py +++ b/superset/db_engine_specs/postgres.py @@ -101,8 +101,6 @@ class PostgresBaseEngineSpec(BaseEngineSpec): engine = "" engine_name = "PostgreSQL" - supports_catalog = True - _time_grain_expressions = { None: "{col}", TimeGrain.SECOND: "DATE_TRUNC('second', {col})", @@ -199,7 +197,10 @@ class PostgresBaseEngineSpec(BaseEngineSpec): class PostgresEngineSpec(BasicParametersMixin, PostgresBaseEngineSpec): engine = "postgresql" engine_aliases = {"postgres"} + supports_dynamic_schema = True + supports_catalog = True + supports_dynamic_catalog = True default_driver = "psycopg2" sqlalchemy_uri_placeholder = ( @@ -296,6 +297,29 @@ class PostgresEngineSpec(BasicParametersMixin, PostgresBaseEngineSpec): return super().get_default_schema_for_query(database, query) + @classmethod + def adjust_engine_params( + cls, + uri: URL, + connect_args: dict[str, Any], + catalog: str | None = None, + schema: str | None = None, + ) -> tuple[URL, dict[str, Any]]: + """ + Set the catalog (database). + """ + if catalog: + uri = uri.set(database=catalog) + + return uri, connect_args + + @classmethod + def get_default_catalog(cls, database: Database) -> str | None: + """ + Return the default catalog for a given database. + """ + return database.url_object.database + @classmethod def get_prequeries( cls, @@ -346,13 +370,13 @@ class PostgresEngineSpec(BasicParametersMixin, PostgresBaseEngineSpec): cls, database: Database, inspector: Inspector, - ) -> list[str]: + ) -> set[str]: """ Return all catalogs. In Postgres, a catalog is called a "database". 
""" - return sorted( + return { catalog for (catalog,) in inspector.bind.execute( """ @@ -360,7 +384,7 @@ SELECT datname FROM pg_database WHERE datistemplate = false; """ ) - ) + } @classmethod def get_table_names( diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py index 34c47eb52..c59b1b1d1 100644 --- a/superset/db_engine_specs/presto.py +++ b/superset/db_engine_specs/presto.py @@ -648,6 +648,8 @@ class PrestoEngineSpec(PrestoBaseEngineSpec): engine_name = "Presto" allows_alias_to_source_column = False + supports_catalog = False + custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = { COLUMN_DOES_NOT_EXIST_REGEX: ( __( @@ -815,11 +817,11 @@ class PrestoEngineSpec(PrestoBaseEngineSpec): cls, database: Database, inspector: Inspector, - ) -> list[str]: + ) -> set[str]: """ Get all catalogs. """ - return [catalog for (catalog,) in inspector.bind.execute("SHOW CATALOGS")] + return {catalog for (catalog,) in inspector.bind.execute("SHOW CATALOGS")} @classmethod def _create_column_info( diff --git a/superset/db_engine_specs/snowflake.py b/superset/db_engine_specs/snowflake.py index 83d382cda..9a82cfcca 100644 --- a/superset/db_engine_specs/snowflake.py +++ b/superset/db_engine_specs/snowflake.py @@ -85,7 +85,7 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec): sqlalchemy_uri_placeholder = "snowflake://" supports_dynamic_schema = True - supports_catalog = True + supports_catalog = False _time_grain_expressions = { None: "{col}", @@ -174,18 +174,18 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec): cls, database: "Database", inspector: Inspector, - ) -> list[str]: + ) -> set[str]: """ Return all catalogs. In Snowflake, a catalog is called a "database". 
""" - return sorted( + return { catalog for (catalog,) in inspector.bind.execute( "SELECT DATABASE_NAME from information_schema.databases" ) - ) + } @classmethod def epoch_to_dttm(cls) -> str: diff --git a/superset/extensions/metadb.py b/superset/extensions/metadb.py index 2d8444cc9..fd697aea8 100644 --- a/superset/extensions/metadb.py +++ b/superset/extensions/metadb.py @@ -270,11 +270,6 @@ class SupersetShillelaghAdapter(Adapter): self.schema = parts.pop(-1) if parts else None self.catalog = parts.pop(-1) if parts else None - if self.catalog: - # TODO (betodealmeida): when SIP-95 is implemented we should check to see if - # the database has multi-catalog enabled, and if so, give access. - raise NotImplementedError("Catalogs are not currently supported") - # If the table has a single integer primary key we use that as the row ID in order # to perform updates and deletes. Otherwise we can only do inserts and selects. self._rowid: str | None = None diff --git a/superset/migrations/shared/catalogs.py b/superset/migrations/shared/catalogs.py new file mode 100644 index 000000000..5d01ecfbf --- /dev/null +++ b/superset/migrations/shared/catalogs.py @@ -0,0 +1,116 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from __future__ import annotations + +import logging + +from alembic import op + +from superset import db, security_manager +from superset.daos.database import DatabaseDAO +from superset.models.core import Database + +logger = logging.getLogger(__name__) + + +def upgrade_schema_perms(engine: str | None = None) -> None: + """ + Update schema permissions to include the catalog part. + + Before SIP-95 schema permissions were stored in the format `[db].[schema]`. With the + introduction of catalogs, any existing permissions need to be renamed to include the + catalog: `[db].[catalog].[schema]`. + """ + bind = op.get_bind() + session = db.Session(bind=bind) + for database in session.query(Database).all(): + db_engine_spec = database.db_engine_spec + if ( + engine and db_engine_spec.engine != engine + ) or not db_engine_spec.supports_catalog: + continue + + catalog = database.get_default_catalog() + ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id) + for schema in database.get_all_schema_names( + catalog=catalog, + cache=False, + ssh_tunnel=ssh_tunnel, + ): + perm = security_manager.get_schema_perm( + database.database_name, + None, + schema, + ) + existing_pvm = security_manager.find_permission_view_menu( + "schema_access", + perm, + ) + if existing_pvm: + existing_pvm.view_menu.name = security_manager.get_schema_perm( + database.database_name, + catalog, + schema, + ) + + session.commit() + + +def downgrade_schema_perms(engine: str | None = None) -> None: + """ + Update schema permissions to not have the catalog part. + + Before SIP-95 schema permissions were stored in the format `[db].[schema]`. With the + introduction of catalogs, any existing permissions need to be renamed to include the + catalog: `[db].[catalog].[schema]`. + + This helped function reverts the process. 
+ """ + bind = op.get_bind() + session = db.Session(bind=bind) + for database in session.query(Database).all(): + db_engine_spec = database.db_engine_spec + if ( + engine and db_engine_spec.engine != engine + ) or not db_engine_spec.supports_catalog: + continue + + catalog = database.get_default_catalog() + ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id) + for schema in database.get_all_schema_names( + catalog=catalog, + cache=False, + ssh_tunnel=ssh_tunnel, + ): + perm = security_manager.get_schema_perm( + database.database_name, + catalog, + schema, + ) + existing_pvm = security_manager.find_permission_view_menu( + "schema_access", + perm, + ) + if existing_pvm: + existing_pvm.view_menu.name = security_manager.get_schema_perm( + database.database_name, + None, + schema, + ) + + session.commit() diff --git a/superset/migrations/versions/2024-05-01_10-52_58d051681a3b_add_catalog_perm_to_tables.py b/superset/migrations/versions/2024-05-01_10-52_58d051681a3b_add_catalog_perm_to_tables.py new file mode 100644 index 000000000..17b33e1d0 --- /dev/null +++ b/superset/migrations/versions/2024-05-01_10-52_58d051681a3b_add_catalog_perm_to_tables.py @@ -0,0 +1,53 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Add catalog_perm to tables + +Revision ID: 58d051681a3b +Revises: 4a33124c18ad +Create Date: 2024-05-01 10:52:31.458433 + +""" + +import sqlalchemy as sa +from alembic import op + +from superset.migrations.shared.catalogs import ( + downgrade_schema_perms, + upgrade_schema_perms, +) + +# revision identifiers, used by Alembic. +revision = "58d051681a3b" +down_revision = "4a33124c18ad" + + +def upgrade(): + op.add_column( + "tables", + sa.Column("catalog_perm", sa.String(length=1000), nullable=True), + ) + op.add_column( + "slices", + sa.Column("catalog_perm", sa.String(length=1000), nullable=True), + ) + upgrade_schema_perms(engine="postgresql") + + +def downgrade(): + op.drop_column("slices", "catalog_perm") + op.drop_column("tables", "catalog_perm") + downgrade_schema_perms(engine="postgresql") diff --git a/superset/models/core.py b/superset/models/core.py index 9a4a1de40..fe486bf2b 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -78,7 +78,7 @@ from superset.sql_parse import Table from superset.superset_typing import OAuth2ClientConfig, ResultSetColumnType from superset.utils import cache as cache_util, core as utils from superset.utils.backports import StrEnum -from superset.utils.core import get_username +from superset.utils.core import DatasourceName, get_username from superset.utils.oauth2 import get_oauth2_access_token config = app.config @@ -313,6 +313,14 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable def metadata_cache_timeout(self) -> dict[str, Any]: return self.get_extra().get("metadata_cache_timeout", {}) + @property + def catalog_cache_enabled(self) -> bool: + return "catalog_cache_timeout" in self.metadata_cache_timeout + + @property + def catalog_cache_timeout(self) -> int | None: + return self.metadata_cache_timeout.get("catalog_cache_timeout") + @property def schema_cache_enabled(self) -> bool: return "schema_cache_timeout" in self.metadata_cache_timeout @@ -549,6 +557,18 @@ class 
Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable yield conn + def get_default_catalog(self) -> str | None: + """ + Return the default configured catalog for the database. + """ + return self.db_engine_spec.get_default_catalog(self) + + def get_default_schema(self, catalog: str | None) -> str | None: + """ + Return the default schema for the database. + """ + return self.db_engine_spec.get_default_schema(self, catalog) + def get_default_schema_for_query(self, query: Query) -> str | None: """ Return the default schema for a given query. @@ -706,19 +726,17 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable key="db:{self.id}:schema:{schema}:table_list", cache=cache_manager.cache, ) - def get_all_table_names_in_schema( # pylint: disable=unused-argument + def get_all_table_names_in_schema( self, catalog: str | None, schema: str, - cache: bool = False, - cache_timeout: int | None = None, - force: bool = False, - ) -> set[tuple[str, str]]: + ) -> set[DatasourceName]: """Parameters need to be passed as keyword arguments. For unused parameters, they are referenced in cache_util.memoized_func decorator. 
+ :param catalog: optional catalog name :param schema: schema name :param cache: whether cache is enabled for the function :param cache_timeout: timeout in seconds for the cache @@ -728,7 +746,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable try: with self.get_inspector(catalog=catalog, schema=schema) as inspector: return { - (table, schema) + DatasourceName(table, schema, catalog) for table in self.db_engine_spec.get_table_names( database=self, inspector=inspector, @@ -742,19 +760,17 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable key="db:{self.id}:schema:{schema}:view_list", cache=cache_manager.cache, ) - def get_all_view_names_in_schema( # pylint: disable=unused-argument + def get_all_view_names_in_schema( self, catalog: str | None, schema: str, - cache: bool = False, - cache_timeout: int | None = None, - force: bool = False, - ) -> set[tuple[str, str]]: + ) -> set[DatasourceName]: """Parameters need to be passed as keyword arguments. For unused parameters, they are referenced in cache_util.memoized_func decorator. 
+ :param catalog: optional catalog name :param schema: schema name :param cache: whether cache is enabled for the function :param cache_timeout: timeout in seconds for the cache @@ -764,7 +780,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable try: with self.get_inspector(catalog=catalog, schema=schema) as inspector: return { - (view, schema) + DatasourceName(view, schema, catalog) for view in self.db_engine_spec.get_view_names( database=self, inspector=inspector, @@ -792,22 +808,17 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable key="db:{self.id}:schema_list", cache=cache_manager.cache, ) - def get_all_schema_names( # pylint: disable=unused-argument + def get_all_schema_names( self, + *, catalog: str | None = None, - cache: bool = False, - cache_timeout: int | None = None, - force: bool = False, ssh_tunnel: SSHTunnel | None = None, - ) -> list[str]: - """Parameters need to be passed as keyword arguments. + ) -> set[str]: + """ + Return the schemas in a given database - For unused parameters, they are referenced in - cache_util.memoized_func decorator. 
- - :param cache: whether cache is enabled for the function - :param cache_timeout: timeout in seconds for the cache - :param force: whether to force refresh the cache + :param catalog: override default catalog + :param ssh_tunnel: SSH tunnel information needed to establish a connection :return: schema list """ try: @@ -819,6 +830,27 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable except Exception as ex: raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex + @cache_util.memoized_func( + key="db:{self.id}:catalog_list", + cache=cache_manager.cache, + ) + def get_all_catalog_names( + self, + *, + ssh_tunnel: SSHTunnel | None = None, + ) -> set[str]: + """ + Return the catalogs in a given database + + :param ssh_tunnel: SSH tunnel information needed to establish a connection + :return: catalog list + """ + try: + with self.get_inspector(ssh_tunnel=ssh_tunnel) as inspector: + return self.db_engine_spec.get_catalog_names(self, inspector) + except Exception as ex: + raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex + @property def db_engine_spec(self) -> builtins.type[db_engine_specs.BaseEngineSpec]: url = make_url_safe(self.sqlalchemy_uri_decrypted) diff --git a/superset/models/helpers.py b/superset/models/helpers.py index f77b3ab4d..100391086 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -209,7 +209,10 @@ class ImportExportMixin: """Get a mapping of foreign name to the local name of foreign keys""" parent_rel = cls.__mapper__.relationships.get(cls.export_parent) if parent_rel: - return {l.name: r.name for (l, r) in parent_rel.local_remote_pairs} # noqa: E741 + return { + local.name: remote.name + for (local, remote) in parent_rel.local_remote_pairs + } return {} @classmethod @@ -772,6 +775,10 @@ class ExploreMixin: # pylint: disable=too-many-public-methods def database(self) -> "Database": raise NotImplementedError() + @property + def catalog(self) -> str: + raise 
NotImplementedError() + @property def schema(self) -> str: raise NotImplementedError() @@ -1025,7 +1032,12 @@ class ExploreMixin: # pylint: disable=too-many-public-methods return df try: - df = self.database.get_df(sql, self.schema, mutator=assign_column_label) + df = self.database.get_df( + sql, + self.catalog, + self.schema, + mutator=assign_column_label, + ) except Exception as ex: # pylint: disable=broad-except df = pd.DataFrame() status = QueryStatus.FAILED diff --git a/superset/models/slice.py b/superset/models/slice.py index 197b09a9e..2a0734b10 100644 --- a/superset/models/slice.py +++ b/superset/models/slice.py @@ -84,6 +84,7 @@ class Slice( # pylint: disable=too-many-public-methods cache_timeout = Column(Integer) perm = Column(String(1000)) schema_perm = Column(String(1000)) + catalog_perm = Column(String(1000), nullable=True, default=None) # the last time a user has saved the chart, changed_on is referencing # when the database row was last written last_saved_at = Column(DateTime, nullable=True) diff --git a/superset/security/manager.py b/superset/security/manager.py index 790301726..d28ed1789 100644 --- a/superset/security/manager.py +++ b/superset/security/manager.py @@ -22,7 +22,7 @@ import logging import re import time from collections import defaultdict -from typing import Any, Callable, cast, NamedTuple, Optional, TYPE_CHECKING, Union +from typing import Any, Callable, cast, NamedTuple, Optional, TYPE_CHECKING from flask import current_app, Flask, g, Request from flask_appbuilder import Model @@ -67,7 +67,7 @@ from superset.security.guest_token import ( GuestTokenUser, GuestUser, ) -from superset.sql_parse import extract_tables_from_jinja_sql +from superset.sql_parse import extract_tables_from_jinja_sql, Table from superset.superset_typing import Metric from superset.utils.core import ( DatasourceName, @@ -89,7 +89,6 @@ if TYPE_CHECKING: from superset.models.dashboard import Dashboard from superset.models.slice import Slice from 
superset.models.sql_lab import Query
-    from superset.sql_parse import Table
     from superset.viz import BaseViz
 
 logger = logging.getLogger(__name__)
@@ -97,8 +96,9 @@ logger = logging.getLogger(__name__)
 DATABASE_PERM_REGEX = re.compile(r"^\[.+\]\.\(id\:(?P<id>\d+)\)$")
 
 
-class DatabaseAndSchema(NamedTuple):
+class DatabaseCatalogSchema(NamedTuple):
     database: str
+    catalog: Optional[str]
     schema: str
 
 
@@ -346,17 +346,51 @@ class SupersetSecurityManager(  # pylint: disable=too-many-public-methods
             return self.get_guest_user_from_request(request)
 
         return None
 
+    def get_catalog_perm(
+        self,
+        database: str,
+        catalog: Optional[str] = None,
+    ) -> Optional[str]:
+        """
+        Return the database specific catalog permission.
+
+        :param database: The Superset database or database name
+        :param catalog: The database catalog name
+        :return: The database specific catalog permission
+        """
+        if catalog is None:
+            return None
+
+        return f"[{database}].[{catalog}]"
+
     def get_schema_perm(
-        self, database: Union["Database", str], schema: Optional[str] = None
+        self,
+        database: str,
+        catalog: Optional[str] = None,
+        schema: Optional[str] = None,
    ) -> Optional[str]:
         """
         Return the database specific schema permission.
 
-        :param database: The Superset database or database name
-        :param schema: The Superset schema name
+        Catalogs were added in SIP-95, and not all databases support them.
Because of + this, the format used for permissions is different depending on whether a + catalog is passed or not: + + [database].[schema] + [database].[catalog].[schema] + + :param database: The database name + :param catalog: The database catalog name + :param schema: The database schema name :return: The database specific schema permission """ - return f"[{database}].[{schema}]" if schema else None + if schema is None: + return None + + if catalog: + return f"[{database}].[{catalog}].[{schema}]" + + return f"[{database}].[{schema}]" @staticmethod def get_database_perm(database_id: int, database_name: str) -> Optional[str]: @@ -370,13 +404,6 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods ) -> Optional[str]: return f"[{database_name}].[{dataset_name}](id:{dataset_id})" - def unpack_database_and_schema(self, schema_permission: str) -> DatabaseAndSchema: - # [database_name].[schema|table] - - schema_name = schema_permission.split(".")[1][1:-1] - database_name = schema_permission.split(".")[0][1:-1] - return DatabaseAndSchema(database_name, schema_name) - def can_access(self, permission_name: str, view_name: str) -> bool: """ Return True if the user can access the FAB permission/view, False otherwise. @@ -436,6 +463,17 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods or self.can_access("database_access", database.perm) # type: ignore ) + def can_access_catalog(self, database: "Database", catalog: str) -> bool: + """ + Return if the user can access the specified catalog. 
+        """
+        catalog_perm = self.get_catalog_perm(database.database_name, catalog)
+        return bool(
+            self.can_access_all_datasources()
+            or self.can_access_database(database)
+            or (catalog_perm and self.can_access("catalog_access", catalog_perm))
+        )
+
     def can_access_schema(self, datasource: "BaseDatasource") -> bool:
         """
         Return True if the user can access the schema associated with specified
@@ -448,6 +486,7 @@
         return (
             self.can_access_all_datasources()
             or self.can_access_database(datasource.database)
+            or self.can_access_catalog(datasource.database, datasource.catalog)
             or self.can_access("schema_access", datasource.schema_perm or "")
         )
 
@@ -706,55 +745,150 @@ class SupersetSecurityManager(  # pylint: disable=too-many-public-methods
         ]
 
     def get_schemas_accessible_by_user(
-        self, database: "Database", schemas: list[str], hierarchical: bool = True
-    ) -> list[str]:
+        self,
+        database: "Database",
+        catalog: Optional[str],
+        schemas: set[str],
+        hierarchical: bool = True,
+    ) -> set[str]:
         """
-        Return the list of SQL schemas accessible by the user.
+        Return a filtered list of the schemas accessible by the user.
+
+        If no catalog is specified, the default catalog is used.
:param database: The SQL database - :param schemas: The list of eligible SQL schemas + :param catalog: An optional database catalog + :param schemas: A set of candidate schemas :param hierarchical: Whether to check using the hierarchical permission logic - :returns: The list of accessible SQL schemas + :returns: The set of accessible database schemas """ # pylint: disable=import-outside-toplevel from superset.connectors.sqla.models import SqlaTable - if hierarchical and self.can_access_database(database): + if hierarchical and ( + self.can_access_database(database) + or (catalog and self.can_access_catalog(database, catalog)) + ): return schemas # schema_access - accessible_schemas = { - self.unpack_database_and_schema(s).schema - for s in self.user_view_menu_names("schema_access") - if s.startswith(f"[{database}].") - } + accessible_schemas: set[str] = set() + schema_access = self.user_view_menu_names("schema_access") + default_catalog = database.get_default_catalog() + default_schema = database.get_default_schema(default_catalog) + + for perm in schema_access: + parts = [part[1:-1] for part in perm.split(".")] + + if parts[0] != database.database_name: + continue + + # [database].[schema] matches when no catalog is specified, or when the user + # specifies the default catalog + if len(parts) == 2 and (catalog is None or catalog == default_catalog): + accessible_schemas.add(parts[1]) + + # [database].[catalog].[schema] matches when the catalog is equal to the + # requested catalog or, when no catalog specified, it's equal to the default + # catalog. 
+ elif len(parts) == 3 and parts[1] == (catalog or default_catalog): + accessible_schemas.add(parts[2]) # datasource_access if perms := self.user_view_menu_names("datasource_access"): tables = ( self.get_session.query(SqlaTable.schema) .filter(SqlaTable.database_id == database.id) - .filter(SqlaTable.schema.isnot(None)) - .filter(SqlaTable.schema != "") .filter(or_(SqlaTable.perm.in_(perms))) .distinct() ) - accessible_schemas.update([table.schema for table in tables]) + accessible_schemas.update( + { + table.schema or default_schema # type: ignore + for table in tables + if (table.schema or default_schema) + } + ) - return [s for s in schemas if s in accessible_schemas] + return schemas & accessible_schemas + + def get_catalogs_accessible_by_user( + self, + database: "Database", + catalogs: set[str], + hierarchical: bool = True, + ) -> set[str]: + """ + Returned a filtered list of the catalogs accessible by the user. + + :param database: The SQL database + :param catalogs: A set of candidate catalogs + :param hierarchical: Whether to check using the hierarchical permission logic + :returns: The set of accessible database catalogs + """ + # pylint: disable=import-outside-toplevel + from superset.connectors.sqla.models import SqlaTable + + if hierarchical and self.can_access_database(database): + return catalogs + + # catalog access + accessible_catalogs: set[str] = set() + catalog_access = self.user_view_menu_names("catalog_access") + default_catalog = database.get_default_catalog() + + for perm in catalog_access: + parts = [part[1:-1] for part in perm.split(".")] + if parts[0] == database.database_name: + accessible_catalogs.add(parts[1]) + + # schema access + schema_access = self.user_view_menu_names("schema_access") + for perm in schema_access: + parts = [part[1:-1] for part in perm.split(".")] + + if parts[0] != database.database_name: + continue + if len(parts) == 2 and default_catalog: + accessible_catalogs.add(default_catalog) + elif len(parts) == 3: + 
accessible_catalogs.add(parts[1])
+
+        # datasource_access
+        if perms := self.user_view_menu_names("datasource_access"):
+            tables = (
+                self.get_session.query(SqlaTable.catalog)
+                .filter(SqlaTable.database_id == database.id)
+                .filter(or_(SqlaTable.perm.in_(perms)))
+                .distinct()
+            )
+            accessible_catalogs.update(
+                {
+                    table.catalog or default_catalog  # type: ignore
+                    for table in tables
+                    if (table.catalog or default_catalog)
+                }
+            )
+
+        return catalogs & accessible_catalogs
 
     def get_datasources_accessible_by_user(  # pylint: disable=invalid-name
         self,
         database: "Database",
         datasource_names: list[DatasourceName],
+        catalog: Optional[str] = None,
         schema: Optional[str] = None,
     ) -> list[DatasourceName]:
         """
-        Return the list of SQL tables accessible by the user.
+        Filter list of SQL tables to the ones accessible by the user.
+
+        When catalog and/or schema are specified, it's assumed that all datasources in
+        `datasource_names` are in the given catalog/schema.
 
         :param database: The SQL database
         :param datasource_names: The list of eligible SQL tables w/ schema
+        :param catalog: The fallback SQL catalog if not present in the table name
         :param schema: The fallback SQL schema if not present in the table name
         :returns: The list of accessible SQL tables w/ schema
         """
@@ -764,22 +898,34 @@ class SupersetSecurityManager(  # pylint: disable=too-many-public-methods
         if self.can_access_database(database):
             return datasource_names
 
+        if catalog:
+            catalog_perm = self.get_catalog_perm(database.database_name, catalog)
+            if catalog_perm and self.can_access("catalog_access", catalog_perm):
+                return datasource_names
+
         if schema:
-            schema_perm = self.get_schema_perm(database, schema)
+            schema_perm = self.get_schema_perm(database, catalog, schema)
             if schema_perm and self.can_access("schema_access", schema_perm):
                 return datasource_names
 
         user_perms = self.user_view_menu_names("datasource_access")
+        catalog_perms = self.user_view_menu_names("catalog_access")
         schema_perms = 
self.user_view_menu_names("schema_access") - user_datasources = SqlaTable.query_datasources_by_permissions( - database, user_perms, schema_perms - ) - if schema: - names = {d.table_name for d in user_datasources if d.schema == schema} - return [d for d in datasource_names if d.table in names] + user_datasources = { + DatasourceName(table.table_name, table.schema, table.catalog) + for table in SqlaTable.query_datasources_by_permissions( + database, + user_perms, + catalog_perms, + schema_perms, + ) + } - full_names = {d.full_name for d in user_datasources} - return [d for d in datasource_names if f"[{database}].[{d}]" in full_names] + return [ + datasource + for datasource in datasource_names + if datasource in user_datasources + ] def merge_perm(self, permission_name: str, view_menu_name: str) -> None: """ @@ -844,6 +990,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods for datasource in datasources: merge_pv("datasource_access", datasource.get_perm()) merge_pv("schema_access", datasource.get_schema_perm()) + merge_pv("catalog_access", datasource.get_catalog_perm()) logger.info("Creating missing database permissions.") databases = self.get_session.query(models.Database).all() @@ -1212,7 +1359,12 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods self.get_session.query(self.permissionview_model) .join(self.permission_model) .join(self.viewmenu_model) - .filter(self.permission_model.name == "schema_access") + .filter( + or_( + self.permission_model.name == "schema_access", + self.permission_model.name == "catalog_access", + ) + ) .filter(self.viewmenu_model.name.like(f"[{database_name}].[%]")) .all() ) @@ -1399,18 +1551,43 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods .values(perm=dataset_perm) ) + # update catalog and schema perms + values: dict[str, Optional[str]] = {} + if target.schema: dataset_schema_perm = self.get_schema_perm( - database.database_name, target.schema + 
database.database_name, + target.catalog, + target.schema, ) self._insert_pvm_on_sqla_event( - mapper, connection, "schema_access", dataset_schema_perm + mapper, + connection, + "schema_access", + dataset_schema_perm, ) target.schema_perm = dataset_schema_perm + values["schema_perm"] = dataset_schema_perm + + if target.catalog: + dataset_catalog_perm = self.get_catalog_perm( + database.database_name, + target.catalog, + ) + self._insert_pvm_on_sqla_event( + mapper, + connection, + "catalog_access", + dataset_catalog_perm, + ) + target.catalog_perm = dataset_catalog_perm + values["catalog_perm"] = dataset_catalog_perm + + if values: connection.execute( dataset_table.update() .where(dataset_table.c.id == target.id) - .values(schema_perm=dataset_schema_perm) + .values(**values) ) def dataset_after_delete( @@ -1467,6 +1644,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods table.select().where(table.c.id == target.id) ).one() current_db_id = current_dataset.database_id + current_catalog = current_dataset.catalog current_schema = current_dataset.schema current_table_name = current_dataset.table_name @@ -1479,14 +1657,21 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods mapper, connection, target.perm, new_dataset_vm_name, target ) - # Updates schema permissions - new_dataset_schema_name = self.get_schema_perm( - target.database.database_name, target.schema + # Updates catalog/schema permissions + dataset_catalog_name = self.get_catalog_perm( + target.database.database_name, + target.catalog, ) - self._update_dataset_schema_perm( + dataset_schema_name = self.get_schema_perm( + target.database.database_name, + target.catalog, + target.schema, + ) + self._update_dataset_catalog_schema_perm( mapper, connection, - new_dataset_schema_name, + dataset_catalog_name, + dataset_schema_name, target, ) @@ -1502,23 +1687,32 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods mapper, connection, 
old_dataset_vm_name, new_dataset_vm_name, target ) - # When schema changes - if current_schema != target.schema: - new_dataset_schema_name = self.get_schema_perm( - target.database.database_name, target.schema + # When catalog/schema change + if current_catalog != target.catalog or current_schema != target.schema: + dataset_catalog_name = self.get_catalog_perm( + target.database.database_name, + target.catalog, ) - self._update_dataset_schema_perm( + dataset_schema_name = self.get_schema_perm( + target.database.database_name, + target.catalog, + target.schema, + ) + self._update_dataset_catalog_schema_perm( mapper, connection, - new_dataset_schema_name, + dataset_catalog_name, + dataset_schema_name, target, ) - def _update_dataset_schema_perm( + # pylint: disable=invalid-name, too-many-arguments + def _update_dataset_catalog_schema_perm( self, mapper: Mapper, connection: Connection, - new_schema_permission_name: Optional[str], + catalog_permission_name: Optional[str], + schema_permission_name: Optional[str], target: "SqlaTable", ) -> None: """ @@ -1530,11 +1724,11 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods :param mapper: The SQLA event mapper :param connection: The SQLA connection - :param new_schema_permission_name: The new schema permission name that changed + :param catalog_permission_name: The new catalog permission name that changed + :param schema_permission_name: The new schema permission name that changed :param target: Dataset that was updated :return: """ - logger.info("Updating schema perm, new: %s", new_schema_permission_name) from superset.connectors.sqla.models import ( # pylint: disable=import-outside-toplevel SqlaTable, ) @@ -1545,18 +1739,30 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods sqlatable_table = SqlaTable.__table__ # pylint: disable=no-member chart_table = Slice.__table__ # pylint: disable=no-member - # insert new schema PVM if it does not exist + # insert new PVMs if they 
don't exist self._insert_pvm_on_sqla_event( - mapper, connection, "schema_access", new_schema_permission_name + mapper, + connection, + "catalog_access", + catalog_permission_name, + ) + self._insert_pvm_on_sqla_event( + mapper, + connection, + "schema_access", + schema_permission_name, ) - # Update dataset (SqlaTable schema_perm field) + # Update dataset connection.execute( sqlatable_table.update() .where( sqlatable_table.c.id == target.id, ) - .values(schema_perm=new_schema_permission_name) + .values( + catalog_perm=catalog_permission_name, + schema_perm=schema_permission_name, + ) ) # Update charts (Slice schema_perm field) @@ -1566,7 +1772,10 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods chart_table.c.datasource_id == target.id, chart_table.c.datasource_type == DatasourceType.TABLE, ) - .values(schema_perm=new_schema_permission_name) + .values( + catalog_perm=catalog_permission_name, + schema_perm=schema_permission_name, + ) ) def _update_dataset_perm( # pylint: disable=too-many-arguments @@ -1923,7 +2132,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods table: Optional["Table"] = None, viz: Optional["BaseViz"] = None, sql: Optional[str] = None, - catalog: Optional[str] = None, # pylint: disable=unused-argument + catalog: Optional[str] = None, schema: Optional[str] = None, ) -> None: """ @@ -1947,7 +2156,6 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods from superset.models.dashboard import Dashboard from superset.models.slice import Slice from superset.models.sql_lab import Query - from superset.sql_parse import Table from superset.utils.core import shortid if sql and database: @@ -1955,6 +2163,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods database=database, sql=sql, schema=schema, + catalog=catalog, client_id=shortid()[:10], user_id=get_user_id(), ) @@ -1970,9 +2179,23 @@ class SupersetSecurityManager( # pylint:
disable=too-many-public-methods return if query: + # Getting the default schema for a query is hard. Users can select the + # schema in SQL Lab, but there's no guarantee that the query actually + # will run in that schema. Each DB engine spec needs to implement the + # necessary logic to enforce that the query runs in the selected schema. + # If the DB engine spec doesn't implement the logic the schema is read + # from the SQLAlchemy URI if possible; if not, we use the SQLAlchemy + # inspector to read it. default_schema = database.get_default_schema_for_query(query) + # Determining the default catalog is much easier, because DB engine + # specs need explicit support for catalogs. + default_catalog = database.get_default_catalog() tables = { - Table(table_.table, table_.schema or default_schema) + Table( + table_.table, + table_.schema or default_schema, + table_.catalog or default_catalog, + ) for table_ in extract_tables_from_jinja_sql(query.sql, database) } elif table: @@ -1981,21 +2204,36 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods denied = set() for table_ in tables: - schema_perm = self.get_schema_perm(database, schema=table_.schema) + catalog_perm = self.get_catalog_perm( + database.database_name, + table_.catalog, + ) + if catalog_perm and self.can_access("catalog_access", catalog_perm): + continue - if not (schema_perm and self.can_access("schema_access", schema_perm)): - datasources = SqlaTable.query_datasources_by_name( - database, table_.table, schema=table_.schema - ) + schema_perm = self.get_schema_perm( + database, + table_.catalog, + table_.schema, + ) + if schema_perm and self.can_access("schema_access", schema_perm): + continue - # Access to any datasource is suffice. 
- for datasource_ in datasources: - if self.can_access( - "datasource_access", datasource_.perm - ) or self.is_owner(datasource_): - break - else: - denied.add(table_) + datasources = SqlaTable.query_datasources_by_name( + database, + table_.table, + schema=table_.schema, + catalog=table_.catalog, + ) + for datasource_ in datasources: + if self.can_access( + "datasource_access", + datasource_.perm, + ) or self.is_owner(datasource_): + # access to any datasource is sufficient + break + else: + denied.add(table_) if denied: raise SupersetSecurityException( diff --git a/superset/utils/cache.py b/superset/utils/cache.py index 48e283e7c..00216fc4b 100644 --- a/superset/utils/cache.py +++ b/superset/utils/cache.py @@ -119,7 +119,11 @@ def memoized_func(key: str, cache: Cache = cache_manager.cache) -> Callable[..., def wrap(f: Callable[..., Any]) -> Callable[..., Any]: def wrapped_f(*args: Any, **kwargs: Any) -> Any: - if not kwargs.get("cache", True): + should_cache = kwargs.pop("cache", True) + force = kwargs.pop("force", False) + cache_timeout = kwargs.pop("cache_timeout", 0) + + if not should_cache: return f(*args, **kwargs) # format the key using args/kwargs passed to the decorated function @@ -129,10 +133,10 @@ def memoized_func(key: str, cache: Cache = cache_manager.cache) -> Callable[..., cache_key = key.format(**bound_args.arguments) obj = cache.get(cache_key) - if not kwargs.get("force") and obj is not None: + if not force and obj is not None: return obj obj = f(*args, **kwargs) - cache.set(cache_key, obj, timeout=kwargs.get("cache_timeout", 0)) + cache.set(cache_key, obj, timeout=cache_timeout) return obj return wrapped_f diff --git a/superset/utils/core.py b/superset/utils/core.py index f02b00443..514cbac75 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -679,11 +679,13 @@ def generic_find_uq_constraint_name( def get_datasource_full_name( - database_name: str, datasource_name: str, schema: str | None = None + database_name: str, + 
datasource_name: str, + catalog: str | None = None, + schema: str | None = None, ) -> str: - if not schema: - return f"[{database_name}].[{datasource_name}]" - return f"[{database_name}].[{schema}].[{datasource_name}]" + parts = [database_name, catalog, schema, datasource_name] + return ".".join([f"[{part}]" for part in parts if part]) def validate_json(obj: bytes | bytearray | str) -> None: @@ -1051,8 +1053,8 @@ def merge_extra_form_data(form_data: dict[str, Any]) -> None: "adhoc_filters", [] ) adhoc_filters.extend( - {"isExtra": True, **fltr} # type: ignore - for fltr in append_adhoc_filters + {"isExtra": True, **adhoc_filter} # type: ignore + for adhoc_filter in append_adhoc_filters ) if append_filters: for key, value in form_data.items(): @@ -1502,6 +1504,7 @@ def shortid() -> str: class DatasourceName(NamedTuple): table: str schema: str + catalog: str | None = None def get_stacktrace() -> str | None: diff --git a/superset/utils/filters.py b/superset/utils/filters.py index 88154a40b..8c4a07994 100644 --- a/superset/utils/filters.py +++ b/superset/utils/filters.py @@ -32,10 +32,12 @@ def get_dataset_access_filters( database_ids = security_manager.get_accessible_databases() perms = security_manager.user_view_menu_names("datasource_access") schema_perms = security_manager.user_view_menu_names("schema_access") + catalog_perms = security_manager.user_view_menu_names("catalog_access") return or_( Database.id.in_(database_ids), base_model.perm.in_(perms), + base_model.catalog_perm.in_(catalog_perms), base_model.schema_perm.in_(schema_perms), *args, ) diff --git a/superset/views/database/mixins.py b/superset/views/database/mixins.py index c6e799e6d..0d104aad5 100644 --- a/superset/views/database/mixins.py +++ b/superset/views/database/mixins.py @@ -211,11 +211,29 @@ class DatabaseMixin: utils.parse_ssl_cert(database.server_cert) database.set_sqlalchemy_uri(database.sqlalchemy_uri) security_manager.add_permission_view_menu("database_access", database.perm) - # adding a 
new database we always want to force refresh schema list - for schema in database.get_all_schema_names(): - security_manager.add_permission_view_menu( - "schema_access", security_manager.get_schema_perm(database, schema) - ) + + # add catalog/schema permissions + if database.db_engine_spec.supports_catalog: + catalogs = database.get_all_catalog_names() + for catalog in catalogs: + security_manager.add_permission_view_menu( + "catalog_access", + security_manager.get_catalog_perm(database.database_name, catalog), + ) + else: + # add a dummy catalog for DBs that don't support them + catalogs = [None] + + for catalog in catalogs: + for schema in database.get_all_schema_names(catalog=catalog): + security_manager.add_permission_view_menu( + "schema_access", + security_manager.get_schema_perm( + database.database_name, + catalog, + schema, + ), + ) def pre_add(self, database: Database) -> None: self._pre_add_update(database) diff --git a/tests/integration_tests/databases/api_tests.py b/tests/integration_tests/databases/api_tests.py index 9e0ee97b1..830625a15 100644 --- a/tests/integration_tests/databases/api_tests.py +++ b/tests/integration_tests/databases/api_tests.py @@ -36,7 +36,6 @@ from sqlalchemy.exc import DBAPIError from sqlalchemy.sql import func from superset import db, security_manager -from superset.commands.database.ssh_tunnel.exceptions import SSHTunnelDatabasePortError # noqa: F401 from superset.connectors.sqla.models import SqlaTable from superset.databases.ssh_tunnel.models import SSHTunnel from superset.databases.utils import make_url_safe # noqa: F401 @@ -296,14 +295,14 @@ class TestDatabaseApi(SupersetTestCase): "superset.commands.database.test_connection.TestConnectionDatabaseCommand.run", ) @mock.patch("superset.commands.database.create.is_feature_enabled") - @mock.patch( - "superset.models.core.Database.get_all_schema_names", - ) + @mock.patch("superset.models.core.Database.get_all_catalog_names") + 
@mock.patch("superset.models.core.Database.get_all_schema_names") def test_create_database_with_ssh_tunnel( self, - mock_test_connection_database_command_run, - mock_create_is_feature_enabled, mock_get_all_schema_names, + mock_get_all_catalog_names, + mock_create_is_feature_enabled, + mock_test_connection_database_command_run, ): """ Database API: Test create with SSH Tunnel @@ -344,14 +343,14 @@ class TestDatabaseApi(SupersetTestCase): "superset.commands.database.test_connection.TestConnectionDatabaseCommand.run", ) @mock.patch("superset.commands.database.create.is_feature_enabled") - @mock.patch( - "superset.models.core.Database.get_all_schema_names", - ) + @mock.patch("superset.models.core.Database.get_all_catalog_names") + @mock.patch("superset.models.core.Database.get_all_schema_names") def test_create_database_with_missing_port_raises_error( self, - mock_test_connection_database_command_run, - mock_create_is_feature_enabled, mock_get_all_schema_names, + mock_get_all_catalog_names, + mock_create_is_feature_enabled, + mock_test_connection_database_command_run, ): """ Database API: Test that missing port raises SSHTunnelDatabaseError @@ -397,15 +396,15 @@ class TestDatabaseApi(SupersetTestCase): ) @mock.patch("superset.commands.database.create.is_feature_enabled") @mock.patch("superset.commands.database.update.is_feature_enabled") - @mock.patch( - "superset.models.core.Database.get_all_schema_names", - ) + @mock.patch("superset.models.core.Database.get_all_catalog_names") + @mock.patch("superset.models.core.Database.get_all_schema_names") def test_update_database_with_ssh_tunnel( self, - mock_test_connection_database_command_run, - mock_create_is_feature_enabled, - mock_update_is_feature_enabled, mock_get_all_schema_names, + mock_get_all_catalog_names, + mock_update_is_feature_enabled, + mock_create_is_feature_enabled, + mock_test_connection_database_command_run, ): """ Database API: Test update Database with SSH Tunnel @@ -458,15 +457,15 @@ class 
TestDatabaseApi(SupersetTestCase): ) @mock.patch("superset.commands.database.create.is_feature_enabled") @mock.patch("superset.commands.database.update.is_feature_enabled") - @mock.patch( - "superset.models.core.Database.get_all_schema_names", - ) + @mock.patch("superset.models.core.Database.get_all_catalog_names") + @mock.patch("superset.models.core.Database.get_all_schema_names") def test_update_database_with_missing_port_raises_error( self, - mock_test_connection_database_command_run, - mock_create_is_feature_enabled, - mock_update_is_feature_enabled, mock_get_all_schema_names, + mock_get_all_catalog_names, + mock_update_is_feature_enabled, + mock_create_is_feature_enabled, + mock_test_connection_database_command_run, ): """ Database API: Test that missing port raises SSHTunnelDatabaseError @@ -523,16 +522,16 @@ class TestDatabaseApi(SupersetTestCase): @mock.patch("superset.commands.database.create.is_feature_enabled") @mock.patch("superset.commands.database.update.is_feature_enabled") @mock.patch("superset.commands.database.ssh_tunnel.delete.is_feature_enabled") - @mock.patch( - "superset.models.core.Database.get_all_schema_names", - ) + @mock.patch("superset.models.core.Database.get_all_catalog_names") + @mock.patch("superset.models.core.Database.get_all_schema_names") def test_delete_ssh_tunnel( self, - mock_test_connection_database_command_run, - mock_create_is_feature_enabled, - mock_update_is_feature_enabled, - mock_delete_is_feature_enabled, mock_get_all_schema_names, + mock_get_all_catalog_names, + mock_delete_is_feature_enabled, + mock_update_is_feature_enabled, + mock_create_is_feature_enabled, + mock_test_connection_database_command_run, ): """ Database API: Test deleting a SSH tunnel via Database update @@ -606,15 +605,15 @@ class TestDatabaseApi(SupersetTestCase): ) @mock.patch("superset.commands.database.create.is_feature_enabled") @mock.patch("superset.commands.database.update.is_feature_enabled") - @mock.patch( - 
"superset.models.core.Database.get_all_schema_names", - ) + @mock.patch("superset.models.core.Database.get_all_catalog_names") + @mock.patch("superset.models.core.Database.get_all_schema_names") def test_update_ssh_tunnel_via_database_api( self, - mock_test_connection_database_command_run, - mock_create_is_feature_enabled, - mock_update_is_feature_enabled, mock_get_all_schema_names, + mock_get_all_catalog_names, + mock_update_is_feature_enabled, + mock_create_is_feature_enabled, + mock_test_connection_database_command_run, ): """ Database API: Test update SSH Tunnel via Database API @@ -684,15 +683,15 @@ class TestDatabaseApi(SupersetTestCase): @mock.patch( "superset.commands.database.test_connection.TestConnectionDatabaseCommand.run", ) - @mock.patch( - "superset.models.core.Database.get_all_schema_names", - ) + @mock.patch("superset.models.core.Database.get_all_catalog_names") + @mock.patch("superset.models.core.Database.get_all_schema_names") @mock.patch("superset.commands.database.create.is_feature_enabled") def test_cascade_delete_ssh_tunnel( self, - mock_test_connection_database_command_run, - mock_get_all_schema_names, mock_create_is_feature_enabled, + mock_get_all_schema_names, + mock_get_all_catalog_names, + mock_test_connection_database_command_run, ): """ Database API: SSH Tunnel gets deleted if Database gets deleted @@ -739,16 +738,16 @@ class TestDatabaseApi(SupersetTestCase): "superset.commands.database.test_connection.TestConnectionDatabaseCommand.run", ) @mock.patch("superset.commands.database.create.is_feature_enabled") - @mock.patch( - "superset.models.core.Database.get_all_schema_names", - ) + @mock.patch("superset.models.core.Database.get_all_catalog_names") + @mock.patch("superset.models.core.Database.get_all_schema_names") @mock.patch("superset.extensions.db.session.rollback") def test_do_not_create_database_if_ssh_tunnel_creation_fails( self, - mock_rollback, - mock_test_connection_database_command_run, - mock_create_is_feature_enabled, 
mock_get_all_schema_names, + mock_get_all_catalog_names, + mock_create_is_feature_enabled, + mock_test_connection_database_command_run, + mock_rollback, ): """ Database API: Test rollback is called if SSH Tunnel creation fails @@ -788,14 +787,14 @@ class TestDatabaseApi(SupersetTestCase): "superset.commands.database.test_connection.TestConnectionDatabaseCommand.run", ) @mock.patch("superset.commands.database.create.is_feature_enabled") - @mock.patch( - "superset.models.core.Database.get_all_schema_names", - ) + @mock.patch("superset.models.core.Database.get_all_catalog_names") + @mock.patch("superset.models.core.Database.get_all_schema_names") def test_get_database_returns_related_ssh_tunnel( self, - mock_test_connection_database_command_run, - mock_create_is_feature_enabled, mock_get_all_schema_names, + mock_get_all_catalog_names, + mock_create_is_feature_enabled, + mock_test_connection_database_command_run, ): """ Database API: Test GET Database returns its related SSH Tunnel @@ -842,13 +841,13 @@ class TestDatabaseApi(SupersetTestCase): @mock.patch( "superset.commands.database.test_connection.TestConnectionDatabaseCommand.run", ) - @mock.patch( - "superset.models.core.Database.get_all_schema_names", - ) + @mock.patch("superset.models.core.Database.get_all_catalog_names") + @mock.patch("superset.models.core.Database.get_all_schema_names") def test_if_ssh_tunneling_flag_is_not_active_it_raises_new_exception( self, - mock_test_connection_database_command_run, mock_get_all_schema_names, + mock_get_all_catalog_names, + mock_test_connection_database_command_run, ): """ Database API: Test raises SSHTunneling feature flag not enabled @@ -1987,13 +1986,13 @@ class TestDatabaseApi(SupersetTestCase): rv = self.client.get(f"api/v1/database/{database.id}/schemas/") response = json.loads(rv.data.decode("utf-8")) - self.assertEqual(schemas, response["result"]) + self.assertEqual(schemas, set(response["result"])) rv = self.client.get( 
f"api/v1/database/{database.id}/schemas/?q={prison.dumps({'force': True})}" ) response = json.loads(rv.data.decode("utf-8")) - self.assertEqual(schemas, response["result"]) + self.assertEqual(schemas, set(response["result"])) def test_database_schemas_not_found(self): """ diff --git a/tests/integration_tests/databases/commands_tests.py b/tests/integration_tests/databases/commands_tests.py index 2388b38ff..8979b91c4 100644 --- a/tests/integration_tests/databases/commands_tests.py +++ b/tests/integration_tests/databases/commands_tests.py @@ -1095,7 +1095,7 @@ class TestTablesDatabaseCommand(SupersetTestCase): @patch("superset.daos.database.DatabaseDAO.find_by_id") def test_database_tables_list_with_unknown_database(self, mock_find_by_id): mock_find_by_id.return_value = None - command = TablesDatabaseCommand(1, "test", False) + command = TablesDatabaseCommand(1, None, "test", False) with pytest.raises(DatabaseNotFoundError) as excinfo: command.run() @@ -1115,7 +1115,7 @@ class TestTablesDatabaseCommand(SupersetTestCase): mock_can_access_database.side_effect = SupersetException("Test Error") mock_g.user = security_manager.find_user("admin") - command = TablesDatabaseCommand(database.id, "main", False) + command = TablesDatabaseCommand(database.id, None, "main", False) with pytest.raises(SupersetException) as excinfo: command.run() assert str(excinfo.value) == "Test Error" @@ -1131,7 +1131,7 @@ class TestTablesDatabaseCommand(SupersetTestCase): mock_can_access_database.side_effect = Exception("Test Error") mock_g.user = security_manager.find_user("admin") - command = TablesDatabaseCommand(database.id, "main", False) + command = TablesDatabaseCommand(database.id, None, "main", False) with pytest.raises(DatabaseTablesUnexpectedError) as excinfo: command.run() assert ( @@ -1154,7 +1154,7 @@ class TestTablesDatabaseCommand(SupersetTestCase): if database.backend == "postgresql" or database.backend == "mysql": return - command = TablesDatabaseCommand(database.id, schema_name, 
False) + command = TablesDatabaseCommand(database.id, None, schema_name, False) result = command.run() assert result["count"] > 0 diff --git a/tests/integration_tests/db_engine_specs/postgres_tests.py b/tests/integration_tests/db_engine_specs/postgres_tests.py index 708b94987..f21dbf54a 100644 --- a/tests/integration_tests/db_engine_specs/postgres_tests.py +++ b/tests/integration_tests/db_engine_specs/postgres_tests.py @@ -531,7 +531,7 @@ def test_get_catalog_names(app_context: AppContext) -> None: return with database.get_inspector() as inspector: - assert PostgresEngineSpec.get_catalog_names(database, inspector) == [ + assert PostgresEngineSpec.get_catalog_names(database, inspector) == { "postgres", "superset", - ] + } diff --git a/tests/integration_tests/model_tests.py b/tests/integration_tests/model_tests.py index b68cb7c05..4dc15a2af 100644 --- a/tests/integration_tests/model_tests.py +++ b/tests/integration_tests/model_tests.py @@ -343,20 +343,20 @@ class TestDatabaseModel(SupersetTestCase): main_db = get_example_database() if main_db.backend == "mysql": - df = main_db.get_df("SELECT 1", None) + df = main_db.get_df("SELECT 1", None, None) self.assertEqual(df.iat[0, 0], 1) - df = main_db.get_df("SELECT 1;", None) + df = main_db.get_df("SELECT 1;", None, None) self.assertEqual(df.iat[0, 0], 1) def test_multi_statement(self): main_db = get_example_database() if main_db.backend == "mysql": - df = main_db.get_df("USE superset; SELECT 1", None) + df = main_db.get_df("USE superset; SELECT 1", None, None) self.assertEqual(df.iat[0, 0], 1) - df = main_db.get_df("USE superset; SELECT ';';", None) + df = main_db.get_df("USE superset; SELECT ';';", None, None) self.assertEqual(df.iat[0, 0], ";") @mock.patch("superset.models.core.create_engine") diff --git a/tests/integration_tests/security_tests.py b/tests/integration_tests/security_tests.py index f22847ca5..5eca7b4a6 100644 --- a/tests/integration_tests/security_tests.py +++ b/tests/integration_tests/security_tests.py 
@@ -898,7 +898,7 @@ class TestRolePermission(SupersetTestCase): db.session.query(SqlaTable).filter_by(table_name="tmp_table1").one() ) self.assertEqual(changed_table1.perm, f"[tmp_db2].[tmp_table1](id:{table1.id})") - self.assertEqual(changed_table1.schema_perm, f"[tmp_db2].[tmp_schema]") # noqa: F541 + self.assertEqual(changed_table1.schema_perm, "[tmp_db2].[tmp_schema]") # noqa: F541 # Test Chart permission changed slice1 = db.session.query(Slice).filter_by(slice_name="tmp_slice1").one() @@ -956,12 +956,12 @@ class TestRolePermission(SupersetTestCase): db.session.query(SqlaTable).filter_by(table_name="tmp_table1").one() ) self.assertEqual(changed_table1.perm, f"[tmp_db1].[tmp_table1](id:{table1.id})") - self.assertEqual(changed_table1.schema_perm, f"[tmp_db1].[tmp_schema_changed]") # noqa: F541 + self.assertEqual(changed_table1.schema_perm, "[tmp_db1].[tmp_schema_changed]") # noqa: F541 # Test Chart schema permission changed slice1 = db.session.query(Slice).filter_by(slice_name="tmp_slice1").one() self.assertEqual(slice1.perm, f"[tmp_db1].[tmp_table1](id:{table1.id})") - self.assertEqual(slice1.schema_perm, f"[tmp_db1].[tmp_schema_changed]") # noqa: F541 + self.assertEqual(slice1.schema_perm, "[tmp_db1].[tmp_schema_changed]") # noqa: F541 # cleanup db.session.delete(slice1) @@ -1069,7 +1069,7 @@ class TestRolePermission(SupersetTestCase): self.assertEqual( changed_table1.perm, f"[tmp_db2].[tmp_table1_changed](id:{table1.id})" ) - self.assertEqual(changed_table1.schema_perm, f"[tmp_db2].[tmp_schema]") # noqa: F541 + self.assertEqual(changed_table1.schema_perm, "[tmp_db2].[tmp_schema]") # noqa: F541 # Test Chart permission changed slice1 = db.session.query(Slice).filter_by(slice_name="tmp_slice1").one() @@ -1158,9 +1158,9 @@ class TestRolePermission(SupersetTestCase): with self.client.application.test_request_context(): database = get_example_database() schemas = security_manager.get_schemas_accessible_by_user( - database, ["1", "2", "3"] + database, None, {"1", 
"2", "3"} ) - self.assertEqual(schemas, ["1", "2", "3"]) # no changes + self.assertEqual(schemas, {"1", "2", "3"}) # no changes @patch("superset.utils.core.g") @patch("superset.security.manager.g") @@ -1171,10 +1171,10 @@ class TestRolePermission(SupersetTestCase): with self.client.application.test_request_context(): database = get_example_database() schemas = security_manager.get_schemas_accessible_by_user( - database, ["1", "2", "3"] + database, None, {"1", "2", "3"} ) # temp_schema is not passed in the params - self.assertEqual(schemas, ["1"]) + self.assertEqual(schemas, {"1"}) delete_schema_perm("[examples].[1]") def test_schemas_accessible_by_user_datasource_access(self): @@ -1183,9 +1183,9 @@ class TestRolePermission(SupersetTestCase): with self.client.application.test_request_context(): with override_user(security_manager.find_user("gamma")): schemas = security_manager.get_schemas_accessible_by_user( - database, ["temp_schema", "2", "3"] + database, None, {"temp_schema", "2", "3"} ) - self.assertEqual(schemas, ["temp_schema"]) + self.assertEqual(schemas, {"temp_schema"}) def test_schemas_accessible_by_user_datasource_and_schema_access(self): # User has schema access to the datasource temp_schema.wb_health_population in examples DB. 
@@ -1194,9 +1194,9 @@ class TestRolePermission(SupersetTestCase): database = get_example_database() with override_user(security_manager.find_user("gamma")): schemas = security_manager.get_schemas_accessible_by_user( - database, ["temp_schema", "2", "3"] + database, None, {"temp_schema", "2", "3"} ) - self.assertEqual(schemas, ["temp_schema", "2"]) + self.assertEqual(schemas, {"temp_schema", "2"}) vm = security_manager.find_permission_view_menu( "schema_access", "[examples].[2]" ) diff --git a/tests/integration_tests/sqllab_tests.py b/tests/integration_tests/sqllab_tests.py index 942653de1..96019c16c 100644 --- a/tests/integration_tests/sqllab_tests.py +++ b/tests/integration_tests/sqllab_tests.py @@ -284,9 +284,15 @@ class TestSqlLab(SupersetTestCase): # sqlite doesn't support database creation return + catalog = examples_db.get_default_catalog() sqllab_test_db_schema_permission_view = ( security_manager.add_permission_view_menu( - "schema_access", f"[{examples_db.name}].[{CTAS_SCHEMA_NAME}]" + "schema_access", + security_manager.get_schema_perm( + examples_db.name, + catalog, + CTAS_SCHEMA_NAME, + ), ) ) schema_perm_role = security_manager.add_role("SchemaPermission") diff --git a/tests/unit_tests/commands/databases/create_test.py b/tests/unit_tests/commands/databases/create_test.py new file mode 100644 index 000000000..405238827 --- /dev/null +++ b/tests/unit_tests/commands/databases/create_test.py @@ -0,0 +1,128 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from unittest.mock import MagicMock + +import pytest +from pytest_mock import MockerFixture + +from superset.commands.database.create import CreateDatabaseCommand +from superset.extensions import security_manager + + +@pytest.fixture() +def database_with_catalog(mocker: MockerFixture) -> MagicMock: + """ + Mock a database with catalogs and schemas. + """ + mocker.patch("superset.commands.database.create.db") + mocker.patch("superset.commands.database.create.TestConnectionDatabaseCommand") + + database = mocker.MagicMock() + database.database_name = "test_database" + database.db_engine_spec.__name__ = "test_engine" + database.db_engine_spec.supports_catalog = True + database.get_all_catalog_names.return_value = ["catalog1", "catalog2"] + database.get_all_schema_names.side_effect = [ + {"schema1", "schema2"}, + {"schema3", "schema4"}, + ] + + DatabaseDAO = mocker.patch("superset.commands.database.create.DatabaseDAO") + DatabaseDAO.create.return_value = database + + return database + + +@pytest.fixture() +def database_without_catalog(mocker: MockerFixture) -> MagicMock: + """ + Mock a database without catalogs. 
+ """ + mocker.patch("superset.commands.database.create.db") + mocker.patch("superset.commands.database.create.TestConnectionDatabaseCommand") + + database = mocker.MagicMock() + database.database_name = "test_database" + database.db_engine_spec.__name__ = "test_engine" + database.db_engine_spec.supports_catalog = False + database.get_all_schema_names.return_value = ["schema1", "schema2"] + + DatabaseDAO = mocker.patch("superset.commands.database.create.DatabaseDAO") + DatabaseDAO.create.return_value = database + + return database + + +def test_create_permissions_with_catalog( + mocker: MockerFixture, + database_with_catalog: MockerFixture, +) -> None: + """ + Test that permissions are created when a database with a catalog is created. + """ + add_permission_view_menu = mocker.patch.object( + security_manager, + "add_permission_view_menu", + ) + + CreateDatabaseCommand( + { + "database_name": "test_database", + "sqlalchemy_uri": "sqlite://", + } + ).run() + + add_permission_view_menu.assert_has_calls( + [ + mocker.call("catalog_access", "[test_database].[catalog1]"), + mocker.call("catalog_access", "[test_database].[catalog2]"), + mocker.call("schema_access", "[test_database].[catalog1].[schema1]"), + mocker.call("schema_access", "[test_database].[catalog1].[schema2]"), + mocker.call("schema_access", "[test_database].[catalog2].[schema3]"), + mocker.call("schema_access", "[test_database].[catalog2].[schema4]"), + ], + any_order=True, + ) + + +def test_create_permissions_without_catalog( + mocker: MockerFixture, + database_without_catalog: MockerFixture, +) -> None: + """ + Test that permissions are created when a database without a catalog is created. 
+ """ + add_permission_view_menu = mocker.patch.object( + security_manager, + "add_permission_view_menu", + ) + + CreateDatabaseCommand( + { + "database_name": "test_database", + "sqlalchemy_uri": "sqlite://", + } + ).run() + + add_permission_view_menu.assert_has_calls( + [ + mocker.call("schema_access", "[test_database].[schema1]"), + mocker.call("schema_access", "[test_database].[schema2]"), + ], + any_order=True, + ) diff --git a/tests/unit_tests/commands/databases/tables_test.py b/tests/unit_tests/commands/databases/tables_test.py new file mode 100644 index 000000000..d9a8583f9 --- /dev/null +++ b/tests/unit_tests/commands/databases/tables_test.py @@ -0,0 +1,203 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from unittest.mock import MagicMock + +import pytest +from pytest_mock import MockerFixture + +from superset.commands.database.tables import TablesDatabaseCommand +from superset.extensions import security_manager +from superset.utils.core import DatasourceName + + +@pytest.fixture() +def database_with_catalog(mocker: MockerFixture) -> MagicMock: + """ + Mock a database with catalogs and schemas. 
+ """ + mocker.patch("superset.commands.database.tables.db") + + database = mocker.MagicMock() + database.database_name = "test_database" + database.get_all_table_names_in_schema.return_value = [ + DatasourceName("table1", "schema1", "catalog1"), + DatasourceName("table2", "schema1", "catalog1"), + ] + database.get_all_view_names_in_schema.return_value = [ + DatasourceName("view1", "schema1", "catalog1"), + ] + + DatabaseDAO = mocker.patch("superset.commands.database.tables.DatabaseDAO") + DatabaseDAO.find_by_id.return_value = database + + return database + + +@pytest.fixture() +def database_without_catalog(mocker: MockerFixture) -> MagicMock: + """ + Mock a database without catalogs but with schemas. + """ + mocker.patch("superset.commands.database.tables.db") + + database = mocker.MagicMock() + database.database_name = "test_database" + database.get_all_table_names_in_schema.return_value = [ + DatasourceName("table1", "schema1"), + DatasourceName("table2", "schema1"), + ] + database.get_all_view_names_in_schema.return_value = [ + DatasourceName("view1", "schema1"), + ] + + DatabaseDAO = mocker.patch("superset.commands.database.tables.DatabaseDAO") + DatabaseDAO.find_by_id.return_value = database + + return database + + +def test_tables_with_catalog( + mocker: MockerFixture, + database_with_catalog: MockerFixture, +) -> None: + """ + Test that permissions are created when a database with a catalog is created. 
+ """ + get_datasources_accessible_by_user = mocker.patch.object( + security_manager, + "get_datasources_accessible_by_user", + side_effect=[ + { + DatasourceName("table1", "schema1", "catalog1"), + DatasourceName("table2", "schema1", "catalog1"), + }, + {DatasourceName("view1", "schema1", "catalog1")}, + ], + ) + + db = mocker.patch("superset.commands.database.tables.db") + table = mocker.MagicMock() + table.name = "table1" + table.extra_dict = {"foo": "bar"} + db.session.query().filter().options().all.return_value = [table] + + payload = TablesDatabaseCommand(1, "catalog1", "schema1", False).run() + assert payload == { + "count": 3, + "result": [ + {"value": "table1", "type": "table", "extra": {"foo": "bar"}}, + {"value": "table2", "type": "table", "extra": None}, + {"value": "view1", "type": "view"}, + ], + } + + get_datasources_accessible_by_user.assert_has_calls( + [ + mocker.call( + database=database_with_catalog, + catalog="catalog1", + schema="schema1", + datasource_names=[ + DatasourceName("table1", "schema1", "catalog1"), + DatasourceName("table2", "schema1", "catalog1"), + ], + ), + mocker.call( + database=database_with_catalog, + catalog="catalog1", + schema="schema1", + datasource_names=[ + DatasourceName("view1", "schema1", "catalog1"), + ], + ), + ], + ) + + database_with_catalog.get_all_table_names_in_schema.assert_called_with( + catalog="catalog1", + schema="schema1", + force=False, + cache=database_with_catalog.table_cache_enabled, + cache_timeout=database_with_catalog.table_cache_timeout, + ) + + +def test_tables_without_catalog( + mocker: MockerFixture, + database_without_catalog: MockerFixture, +) -> None: + """ + Test that permissions are created when a database without a catalog is created. 
+ """ + get_datasources_accessible_by_user = mocker.patch.object( + security_manager, + "get_datasources_accessible_by_user", + side_effect=[ + { + DatasourceName("table1", "schema1"), + DatasourceName("table2", "schema1"), + }, + {DatasourceName("view1", "schema1")}, + ], + ) + + db = mocker.patch("superset.commands.database.tables.db") + table = mocker.MagicMock() + table.name = "table1" + table.extra_dict = {"foo": "bar"} + db.session.query().filter().options().all.return_value = [table] + + payload = TablesDatabaseCommand(1, None, "schema1", False).run() + assert payload == { + "count": 3, + "result": [ + {"value": "table1", "type": "table", "extra": {"foo": "bar"}}, + {"value": "table2", "type": "table", "extra": None}, + {"value": "view1", "type": "view"}, + ], + } + + get_datasources_accessible_by_user.assert_has_calls( + [ + mocker.call( + database=database_without_catalog, + catalog=None, + schema="schema1", + datasource_names=[ + DatasourceName("table1", "schema1"), + DatasourceName("table2", "schema1"), + ], + ), + mocker.call( + database=database_without_catalog, + catalog=None, + schema="schema1", + datasource_names=[ + DatasourceName("view1", "schema1"), + ], + ), + ], + ) + + database_without_catalog.get_all_table_names_in_schema.assert_called_with( + catalog=None, + schema="schema1", + force=False, + cache=database_without_catalog.table_cache_enabled, + cache_timeout=database_without_catalog.table_cache_timeout, + ) diff --git a/tests/unit_tests/commands/databases/update_test.py b/tests/unit_tests/commands/databases/update_test.py new file mode 100644 index 000000000..300efb62e --- /dev/null +++ b/tests/unit_tests/commands/databases/update_test.py @@ -0,0 +1,272 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from unittest.mock import MagicMock + +import pytest +from pytest_mock import MockerFixture + +from superset.commands.database.update import UpdateDatabaseCommand +from superset.extensions import security_manager + + +@pytest.fixture() +def database_with_catalog(mocker: MockerFixture) -> MagicMock: + """ + Mock a database with catalogs and schemas. + """ + mocker.patch("superset.commands.database.update.db") + + database = mocker.MagicMock() + database.database_name = "my_db" + database.db_engine_spec.__name__ = "test_engine" + database.db_engine_spec.supports_catalog = True + database.get_all_catalog_names.return_value = ["catalog1", "catalog2"] + database.get_all_schema_names.side_effect = [ + ["schema1", "schema2"], + ["schema3", "schema4"], + ] + database.get_default_catalog.return_value = "catalog2" + + return database + + +@pytest.fixture() +def database_without_catalog(mocker: MockerFixture) -> MagicMock: + """ + Mock a database without catalogs. 
+ """ + mocker.patch("superset.commands.database.update.db") + + database = mocker.MagicMock() + database.database_name = "my_db" + database.db_engine_spec.__name__ = "test_engine" + database.db_engine_spec.supports_catalog = False + database.get_all_schema_names.return_value = ["schema1", "schema2"] + + return database + + +def test_update_with_catalog( + mocker: MockerFixture, + database_with_catalog: MockerFixture, +) -> None: + """ + Test that permissions are updated correctly. + + In this test, the database has two catalogs with two schemas each: + + - catalog1 + - schema1 + - schema2 + - catalog2 + - schema3 + - schema4 + + When update is called, only `catalog2.schema3` has permissions associated with it, + so `catalog1.*` and `catalog2.schema4` are added. + """ + DatabaseDAO = mocker.patch("superset.commands.database.update.DatabaseDAO") + DatabaseDAO.find_by_id.return_value = database_with_catalog + DatabaseDAO.update.return_value = database_with_catalog + + find_permission_view_menu = mocker.patch.object( + security_manager, + "find_permission_view_menu", + ) + find_permission_view_menu.side_effect = [ + None, # first catalog is new + "[my_db].[catalog2]", # second catalog already exists + "[my_db].[catalog2].[schema3]", # first schema already exists + None, # second schema is new + # these are called when checking for existing perms in [db].[schema] format + None, + None, + ] + add_permission_view_menu = mocker.patch.object( + security_manager, + "add_permission_view_menu", + ) + + UpdateDatabaseCommand(1, {}).run() + + add_permission_view_menu.assert_has_calls( + [ + # first catalog is added with all schemas + mocker.call("catalog_access", "[my_db].[catalog1]"), + mocker.call("schema_access", "[my_db].[catalog1].[schema1]"), + mocker.call("schema_access", "[my_db].[catalog1].[schema2]"), + # second catalog already exists, only `schema4` is added + mocker.call("schema_access", "[my_db].[catalog2].[schema4]"), + ], + ) + + +def test_update_without_catalog( 
+ mocker: MockerFixture, + database_without_catalog: MockerFixture, +) -> None: + """ + Test that permissions are updated correctly. + + In this test, the database has no catalogs and two schemas: + + - schema1 + - schema2 + + When update is called, only `schema2` has permissions associated with it, so `schema1` + is added. + """ + DatabaseDAO = mocker.patch("superset.commands.database.update.DatabaseDAO") + DatabaseDAO.find_by_id.return_value = database_without_catalog + DatabaseDAO.update.return_value = database_without_catalog + + find_permission_view_menu = mocker.patch.object( + security_manager, + "find_permission_view_menu", + ) + find_permission_view_menu.side_effect = [ + None, # schema1 has no permissions + "[my_db].[schema2]", # second schema already exists + ] + add_permission_view_menu = mocker.patch.object( + security_manager, + "add_permission_view_menu", + ) + + UpdateDatabaseCommand(1, {}).run() + + add_permission_view_menu.assert_called_with( + "schema_access", + "[my_db].[schema1]", + ) + + +def test_rename_with_catalog( + mocker: MockerFixture, + database_with_catalog: MockerFixture, +) -> None: + """ + Test that permissions are renamed correctly. + + In this test, the database has two catalogs with two schemas each: + + - catalog1 + - schema1 + - schema2 + - catalog2 + - schema3 + - schema4 + + When update is called, only `catalog2.schema3` has permissions associated with it, + so `catalog1.*` and `catalog2.schema4` are added. Additionally, the database has + been renamed from `my_db` to `my_other_db`. 
+ """ + DatabaseDAO = mocker.patch("superset.commands.database.update.DatabaseDAO") + original_database = mocker.MagicMock() + original_database.database_name = "my_db" + DatabaseDAO.find_by_id.return_value = original_database + database_with_catalog.database_name = "my_other_db" + DatabaseDAO.update.return_value = database_with_catalog + DatabaseDAO.get_datasets.return_value = [] + + find_permission_view_menu = mocker.patch.object( + security_manager, + "find_permission_view_menu", + ) + catalog2_pvm = mocker.MagicMock() + catalog2_schema3_pvm = mocker.MagicMock() + find_permission_view_menu.side_effect = [ + # these are called when adding the permissions: + None, # first catalog is new + "[my_db].[catalog2]", # second catalog already exists + "[my_db].[catalog2].[schema3]", # first schema already exists + None, # second schema is new + # these are called when renaming the permissions: + catalog2_pvm, # old [my_db].[catalog2] + catalog2_schema3_pvm, # old [my_db].[catalog2].[schema3] + None, # [my_db].[catalog2].[schema4] doesn't exist + ] + add_permission_view_menu = mocker.patch.object( + security_manager, + "add_permission_view_menu", + ) + + UpdateDatabaseCommand(1, {}).run() + + add_permission_view_menu.assert_has_calls( + [ + # first catalog is added with all schemas with the new DB name + mocker.call("catalog_access", "[my_other_db].[catalog1]"), + mocker.call("schema_access", "[my_other_db].[catalog1].[schema1]"), + mocker.call("schema_access", "[my_other_db].[catalog1].[schema2]"), + # second catalog already exists, only `schema4` is added + mocker.call("schema_access", "[my_other_db].[catalog2].[schema4]"), + ], + ) + + assert catalog2_pvm.view_menu.name == "[my_other_db].[catalog2]" + assert catalog2_schema3_pvm.view_menu.name == "[my_other_db].[catalog2].[schema3]" + + +def test_rename_without_catalog( + mocker: MockerFixture, + database_without_catalog: MockerFixture, +) -> None: + """ + Test that permissions are renamed correctly. 
+ + In this test, the database has no catalogs and two schemas: + + - schema1 + - schema2 + + When update is called, only `schema2` has permissions associated with it, so `schema1` + is added. Additionally, the database has been renamed from `my_db` to `my_other_db`. + """ + DatabaseDAO = mocker.patch("superset.commands.database.update.DatabaseDAO") + original_database = mocker.MagicMock() + original_database.database_name = "my_db" + DatabaseDAO.find_by_id.return_value = original_database + database_without_catalog.database_name = "my_other_db" + DatabaseDAO.update.return_value = database_without_catalog + DatabaseDAO.get_datasets.return_value = [] + + find_permission_view_menu = mocker.patch.object( + security_manager, + "find_permission_view_menu", + ) + schema2_pvm = mocker.MagicMock() + find_permission_view_menu.side_effect = [ + None, # schema1 has no permissions + "[my_db].[schema2]", # second schema already exists + None, # [my_db].[schema1] doesn't exist + schema2_pvm, # old [my_db].[schema2] + ] + add_permission_view_menu = mocker.patch.object( + security_manager, + "add_permission_view_menu", + ) + + UpdateDatabaseCommand(1, {}).run() + + add_permission_view_menu.assert_called_with( + "schema_access", + "[my_other_db].[schema1]", + ) + + assert schema2_pvm.view_menu.name == "[my_other_db].[schema2]" diff --git a/tests/unit_tests/conftest.py b/tests/unit_tests/conftest.py index 71b3bff33..3905a15b3 100644 --- a/tests/unit_tests/conftest.py +++ b/tests/unit_tests/conftest.py @@ -88,6 +88,8 @@ def app(request: SubRequest) -> Iterator[SupersetApp]: app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] = False app.config["TESTING"] = True app.config["RATELIMIT_ENABLED"] = False + app.config["CACHE_CONFIG"] = {} + app.config["DATA_CACHE_CONFIG"] = {} # loop over extra configs passed in by tests # and update the app config diff --git a/tests/unit_tests/connectors/sqla/models_test.py b/tests/unit_tests/connectors/sqla/models_test.py index 00b4b0a31..687295baf 100644 --- 
a/tests/unit_tests/connectors/sqla/models_test.py
+++ b/tests/unit_tests/connectors/sqla/models_test.py
@@ -17,9 +17,11 @@
 import pytest
 from pytest_mock import MockerFixture
+from sqlalchemy import create_engine
 
 from superset.connectors.sqla.models import SqlaTable
 from superset.exceptions import OAuth2RedirectError
+from superset.models.core import Database
 from superset.superset_typing import QueryObjectDict
 
@@ -64,3 +66,124 @@ def test_query_bubbles_errors(mocker: MockerFixture) -> None:
     }
     with pytest.raises(OAuth2RedirectError):
         sqla_table.query(query_obj)
+
+
+def test_permissions_without_catalog() -> None:
+    """
+    Test permissions when the table has no catalog.
+    """
+    database = Database(database_name="my_db")
+    sqla_table = SqlaTable(
+        table_name="my_sqla_table",
+        columns=[],
+        metrics=[],
+        database=database,
+        schema="schema1",
+        catalog=None,
+        id=1,
+    )
+
+    assert sqla_table.get_perm() == "[my_db].[my_sqla_table](id:1)"
+    assert sqla_table.get_catalog_perm() is None
+    assert sqla_table.get_schema_perm() == "[my_db].[schema1]"
+
+
+def test_permissions_with_catalog() -> None:
+    """
+    Test permissions when the table has a catalog set.
+    """
+    database = Database(database_name="my_db")
+    sqla_table = SqlaTable(
+        table_name="my_sqla_table",
+        columns=[],
+        metrics=[],
+        database=database,
+        schema="schema1",
+        catalog="db1",
+        id=1,
+    )
+
+    assert sqla_table.get_perm() == "[my_db].[my_sqla_table](id:1)"
+    assert sqla_table.get_catalog_perm() == "[my_db].[db1]"
+    assert sqla_table.get_schema_perm() == "[my_db].[db1].[schema1]"
+
+
+def test_query_datasources_by_name(mocker: MockerFixture) -> None:
+    """
+    Test the `query_datasources_by_name` method.
+ """ + db = mocker.patch("superset.connectors.sqla.models.db") + + database = Database(database_name="my_db", id=1) + sqla_table = SqlaTable( + table_name="my_sqla_table", + columns=[], + metrics=[], + database=database, + ) + + sqla_table.query_datasources_by_name(database, "my_table") + db.session.query().filter_by.assert_called_with( + database_id=1, + table_name="my_table", + ) + + sqla_table.query_datasources_by_name(database, "my_table", "db1", "schema1") + db.session.query().filter_by.assert_called_with( + database_id=1, + table_name="my_table", + catalog="db1", + schema="schema1", + ) + + +def test_query_datasources_by_permissions(mocker: MockerFixture) -> None: + """ + Test the `query_datasources_by_permissions` method. + """ + db = mocker.patch("superset.connectors.sqla.models.db") + + engine = create_engine("sqlite://") + database = Database(database_name="my_db", id=1) + sqla_table = SqlaTable( + table_name="my_sqla_table", + columns=[], + metrics=[], + database=database, + ) + + sqla_table.query_datasources_by_permissions(database, set(), set(), set()) + db.session.query().filter_by.assert_called_with(database_id=1) + clause = db.session.query().filter_by().filter.mock_calls[0].args[0] + assert str(clause.compile(engine, compile_kwargs={"literal_binds": True})) == "" + + +def test_query_datasources_by_permissions_with_catalog_schema( + mocker: MockerFixture, +) -> None: + """ + Test the `query_datasources_by_permissions` method passing a catalog and schema. 
+ """ + db = mocker.patch("superset.connectors.sqla.models.db") + + engine = create_engine("sqlite://") + database = Database(database_name="my_db", id=1) + sqla_table = SqlaTable( + table_name="my_sqla_table", + columns=[], + metrics=[], + database=database, + ) + sqla_table.query_datasources_by_permissions( + database, + {"[my_db].[table1](id:1)"}, + {"[my_db].[db1]"}, + # pass as list to have deterministic order for test + ["[my_db].[db1].[schema1]", "[my_other_db].[schema]"], # type: ignore + ) + clause = db.session.query().filter_by().filter.mock_calls[0].args[0] + assert str(clause.compile(engine, compile_kwargs={"literal_binds": True})) == ( + "tables.perm IN ('[my_db].[table1](id:1)') OR " + "tables.schema_perm IN ('[my_db].[db1].[schema1]', '[my_other_db].[schema]') OR " + "tables.catalog_perm IN ('[my_db].[db1]')" + ) diff --git a/tests/unit_tests/databases/api_test.py b/tests/unit_tests/databases/api_test.py index 8b309d573..f3f556e2e 100644 --- a/tests/unit_tests/databases/api_test.py +++ b/tests/unit_tests/databases/api_test.py @@ -2066,3 +2066,101 @@ def test_table_extra_metadata_unauthorized( } ] } + + +def test_catalogs( + mocker: MockFixture, + client: Any, + full_api_access: None, +) -> None: + """ + Test the `catalogs` endpoint. 
+ """ + database = mocker.MagicMock() + database.get_all_catalog_names.return_value = {"db1", "db2"} + DatabaseDAO = mocker.patch("superset.databases.api.DatabaseDAO") + DatabaseDAO.find_by_id.return_value = database + + security_manager = mocker.patch( + "superset.databases.api.security_manager", + new=mocker.MagicMock(), + ) + security_manager.get_catalogs_accessible_by_user.return_value = {"db2"} + + response = client.get("/api/v1/database/1/catalogs/") + assert response.status_code == 200 + assert response.json == {"result": ["db2"]} + database.get_all_catalog_names.assert_called_with( + cache=database.catalog_cache_enabled, + cache_timeout=database.catalog_cache_timeout, + force=False, + ) + security_manager.get_catalogs_accessible_by_user.assert_called_with( + database, + {"db1", "db2"}, + ) + + response = client.get("/api/v1/database/1/catalogs/?q=(force:!t)") + database.get_all_catalog_names.assert_called_with( + cache=database.catalog_cache_enabled, + cache_timeout=database.catalog_cache_timeout, + force=True, + ) + + +def test_schemas( + mocker: MockFixture, + client: Any, + full_api_access: None, +) -> None: + """ + Test the `schemas` endpoint. 
+ """ + from superset.databases.api import DatabaseRestApi + + database = mocker.MagicMock() + database.get_all_schema_names.return_value = {"schema1", "schema2"} + datamodel = mocker.patch.object(DatabaseRestApi, "datamodel") + datamodel.get.return_value = database + + security_manager = mocker.patch( + "superset.databases.api.security_manager", + new=mocker.MagicMock(), + ) + security_manager.get_schemas_accessible_by_user.return_value = {"schema2"} + + response = client.get("/api/v1/database/1/schemas/") + assert response.status_code == 200 + assert response.json == {"result": ["schema2"]} + database.get_all_schema_names.assert_called_with( + catalog=None, + cache=database.schema_cache_enabled, + cache_timeout=database.schema_cache_timeout, + force=False, + ) + security_manager.get_schemas_accessible_by_user.assert_called_with( + database, + None, + {"schema1", "schema2"}, + ) + + response = client.get("/api/v1/database/1/schemas/?q=(force:!t)") + database.get_all_schema_names.assert_called_with( + catalog=None, + cache=database.schema_cache_enabled, + cache_timeout=database.schema_cache_timeout, + force=True, + ) + + response = client.get("/api/v1/database/1/schemas/?q=(force:!t,catalog:catalog2)") + database.get_all_schema_names.assert_called_with( + catalog="catalog2", + cache=database.schema_cache_enabled, + cache_timeout=database.schema_cache_timeout, + force=True, + ) + security_manager.get_schemas_accessible_by_user.assert_called_with( + database, + "catalog2", + {"schema1", "schema2"}, + ) diff --git a/tests/unit_tests/databases/filters_test.py b/tests/unit_tests/databases/filters_test.py new file mode 100644 index 000000000..a1e51fce2 --- /dev/null +++ b/tests/unit_tests/databases/filters_test.py @@ -0,0 +1,128 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from flask_appbuilder.models.sqla.interface import SQLAInterface +from pytest_mock import MockerFixture +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + +from superset.databases.filters import can_access_databases, DatabaseFilter +from superset.extensions import security_manager + + +def test_can_access_databases(mocker: MockerFixture) -> None: + """ + Test the `can_access_databases` function. + """ + mocker.patch.object( + security_manager, + "user_view_menu_names", + side_effect=[ + { + "[my_db].[examples].[public].[table1](id:1)", + "[my_other_db].[examples].[public].[table1](id:2)", + }, + {"[my_db].(id:42)", "[my_other_db].(id:43)"}, + {"[my_db].[examples]", "[my_db].[other]"}, + { + "[my_db].[examples].[information_schema]", + "[my_db].[other].[secret]", + "[third_db].[schema]", + }, + ], + ) + + assert can_access_databases("datasource_access") == {"my_db", "my_other_db"} + assert can_access_databases("database_access") == {"my_db", "my_other_db"} + assert can_access_databases("catalog_access") == {"my_db"} + assert can_access_databases("schema_access") == {"my_db", "third_db"} + + +def test_database_filter_full_db_access(mocker: MockerFixture) -> None: + """ + Test the `DatabaseFilter` class when the user has full database access. + + In this case the query should be returned unmodified. 
+ """ + from superset.models.core import Database + + current_app = mocker.patch("superset.databases.filters.current_app") + current_app.config = {"EXTRA_DYNAMIC_QUERY_FILTERS": False} + mocker.patch.object(security_manager, "can_access_all_databases", return_value=True) + + engine = create_engine("sqlite://") + Session = sessionmaker(bind=engine) + session = Session() + query = session.query(Database) + + filter_ = DatabaseFilter("id", SQLAInterface(Database)) + filtered_query = filter_.apply(query, None) + + assert filtered_query == query + + +def test_database_filter(mocker: MockerFixture) -> None: + """ + Test the `DatabaseFilter` class with specific permissions. + """ + from superset.models.core import Database + + current_app = mocker.patch("superset.databases.filters.current_app") + current_app.config = {"EXTRA_DYNAMIC_QUERY_FILTERS": False} + mocker.patch.object( + security_manager, + "can_access_all_databases", + return_value=False, + ) + mocker.patch.object( + security_manager, + "user_view_menu_names", + side_effect=[ + # return lists instead of sets to ensure order + ["[my_db].(id:42)", "[my_other_db].(id:43)"], + ["[my_db].[examples]", "[my_db].[other]"], + [ + "[my_db].[examples].[information_schema]", + "[my_db].[other].[secret]", + "[third_db].[schema]", + ], + [ + "[my_db].[examples].[public].[table1](id:1)", + "[my_other_db].[examples].[public].[table1](id:2)", + ], + ], + ) + + engine = create_engine("sqlite://") + Session = sessionmaker(bind=engine) + session = Session() + query = session.query(Database) + + filter_ = DatabaseFilter("id", SQLAInterface(Database)) + filtered_query = filter_.apply(query, None) + + compiled_query = filtered_query.statement.compile( + engine, + compile_kwargs={"literal_binds": True}, + ) + space = " " # pre-commit removes trailing spaces... 
+ assert ( + str(compiled_query) + == f"""SELECT dbs.uuid, dbs.created_on, dbs.changed_on, dbs.id, dbs.verbose_name, dbs.database_name, dbs.sqlalchemy_uri, dbs.password, dbs.cache_timeout, dbs.select_as_create_table_as, dbs.expose_in_sqllab, dbs.configuration_method, dbs.allow_run_async, dbs.allow_file_upload, dbs.allow_ctas, dbs.allow_cvas, dbs.allow_dml, dbs.force_ctas_schema, dbs.extra, dbs.encrypted_extra, dbs.impersonate_user, dbs.server_cert, dbs.is_managed_externally, dbs.external_url, dbs.created_by_fk, dbs.changed_by_fk{space} +FROM dbs{space} +WHERE '[' || dbs.database_name || '].(id:' || CAST(dbs.id AS VARCHAR) || ')' IN ('[my_db].(id:42)', '[my_other_db].(id:43)') OR dbs.database_name IN ('my_db', 'my_other_db', 'third_db')""" + ) diff --git a/tests/unit_tests/db_engine_specs/test_base.py b/tests/unit_tests/db_engine_specs/test_base.py index 3bc05ee20..0950dcb43 100644 --- a/tests/unit_tests/db_engine_specs/test_base.py +++ b/tests/unit_tests/db_engine_specs/test_base.py @@ -301,3 +301,13 @@ def test_extra_table_metadata(mocker: MockFixture) -> None: ) warnings.warn.assert_called() + + +def test_get_default_catalog(mocker: MockFixture) -> None: + """ + Test the `get_default_catalog` method. + """ + from superset.db_engine_specs.base import BaseEngineSpec + + database = mocker.MagicMock() + assert BaseEngineSpec.get_default_catalog(database) is None diff --git a/tests/unit_tests/db_engine_specs/test_postgres.py b/tests/unit_tests/db_engine_specs/test_postgres.py index bec77b3b7..d955d3ce5 100644 --- a/tests/unit_tests/db_engine_specs/test_postgres.py +++ b/tests/unit_tests/db_engine_specs/test_postgres.py @@ -175,3 +175,33 @@ SELECT * FROM some_table; str(excinfo.value) == "Users are not allowed to set a search path for security reasons." ) + + +def test_adjust_engine_params() -> None: + """ + Test `adjust_engine_params`. + + The method can be used to adjust the catalog (database) dynamically. 
+ """ + from superset.db_engine_specs.postgres import PostgresEngineSpec + + adjusted = PostgresEngineSpec.adjust_engine_params( + make_url("postgresql://user:password@host:5432/dev"), + {}, + catalog="prod", + ) + assert adjusted == (make_url("postgresql://user:password@host:5432/prod"), {}) + + +def test_get_default_catalog() -> None: + """ + Test `get_default_catalog`. + """ + from superset.db_engine_specs.postgres import PostgresEngineSpec + from superset.models.core import Database + + database = Database( + database_name="postgres", + sqlalchemy_uri="postgresql://user:password@host:5432/dev", + ) + assert PostgresEngineSpec.get_default_catalog(database) == "dev" diff --git a/tests/unit_tests/explore/utils_test.py b/tests/unit_tests/explore/utils_test.py index 9638392a5..2b274106c 100644 --- a/tests/unit_tests/explore/utils_test.py +++ b/tests/unit_tests/explore/utils_test.py @@ -272,6 +272,7 @@ def test_query_no_access(mocker: MockFixture, client) -> None: from superset.models.sql_lab import Query database = mocker.MagicMock() + database.get_default_catalog.return_value = None database.get_default_schema_for_query.return_value = "public" mocker.patch( query_find_by_id, diff --git a/tests/unit_tests/models/core_test.py b/tests/unit_tests/models/core_test.py index ce3ad1822..5ee521fc0 100644 --- a/tests/unit_tests/models/core_test.py +++ b/tests/unit_tests/models/core_test.py @@ -229,3 +229,63 @@ def test_get_prequeries(mocker: MockFixture) -> None: conn.cursor().execute.assert_has_calls( [mocker.call("set a=1"), mocker.call("set b=2")] ) + + +def test_catalog_cache() -> None: + """ + Test the catalog cache. + """ + database = Database( + database_name="db", + sqlalchemy_uri="sqlite://", + extra=json.dumps({"metadata_cache_timeout": {"catalog_cache_timeout": 10}}), + ) + + assert database.catalog_cache_enabled + assert database.catalog_cache_timeout == 10 + + +def test_get_default_catalog() -> None: + """ + Test the `get_default_catalog` method. 
+ """ + database = Database( + database_name="db", + sqlalchemy_uri="postgresql://user:password@host:5432/examples", + ) + + assert database.get_default_catalog() == "examples" + + +def test_get_default_schema(mocker: MockFixture) -> None: + """ + Test the `get_default_schema` method. + """ + database = Database( + database_name="db", + sqlalchemy_uri="postgresql://user:password@host:5432/examples", + ) + + get_inspector = mocker.patch.object(database, "get_inspector") + with get_inspector() as inspector: + inspector.default_schema_name = "public" + + assert database.get_default_schema("examples") == "public" + get_inspector.assert_called_with(catalog="examples") + + +def test_get_all_catalog_names(mocker: MockFixture) -> None: + """ + Test the `get_all_catalog_names` method. + """ + database = Database( + database_name="db", + sqlalchemy_uri="postgresql://user:password@host:5432/examples", + ) + + get_inspector = mocker.patch.object(database, "get_inspector") + with get_inspector() as inspector: + inspector.bind.execute.return_value = [("examples",), ("other",)] + + assert database.get_all_catalog_names(force=True) == {"examples", "other"} + get_inspector.assert_called_with(ssh_tunnel=None) diff --git a/tests/unit_tests/security/manager_test.py b/tests/unit_tests/security/manager_test.py index 7ed32b0ab..033446290 100644 --- a/tests/unit_tests/security/manager_test.py +++ b/tests/unit_tests/security/manager_test.py @@ -26,7 +26,10 @@ from superset.connectors.sqla.models import Database, SqlaTable from superset.exceptions import SupersetSecurityException from superset.extensions import appbuilder from superset.models.slice import Slice -from superset.security.manager import query_context_modified, SupersetSecurityManager +from superset.security.manager import ( + query_context_modified, + SupersetSecurityManager, +) from superset.sql_parse import Table from superset.superset_typing import AdhocMetric from superset.utils.core import override_user @@ -208,6 +211,7 @@ 
def test_raise_for_access_query_default_schema( SqlaTable.query_datasources_by_name.return_value = [] database = mocker.MagicMock() + database.get_default_catalog.return_value = None database.get_default_schema_for_query.return_value = "public" query = mocker.MagicMock() query.database = database @@ -262,6 +266,7 @@ def test_raise_for_access_jinja_sql(mocker: MockFixture, app_context: None) -> N SqlaTable.query_datasources_by_name.return_value = [] database = mocker.MagicMock() + database.get_default_catalog.return_value = None database.get_default_schema_for_query.return_value = "public" query = mocker.MagicMock() query.database = database @@ -531,3 +536,28 @@ def test_query_context_modified_mixed_chart(mocker: MockFixture) -> None: } query_context.queries = [QueryObject(metrics=requested_metrics)] # type: ignore assert not query_context_modified(query_context) + + +def test_get_catalog_perm() -> None: + """ + Test the `get_catalog_perm` method. + """ + sm = SupersetSecurityManager(appbuilder) + + assert sm.get_catalog_perm("my_db", None) is None + assert sm.get_catalog_perm("my_db", "my_catalog") == "[my_db].[my_catalog]" + + +def test_get_schema_perm() -> None: + """ + Test the `get_schema_perm` method. + """ + sm = SupersetSecurityManager(appbuilder) + + assert sm.get_schema_perm("my_db", None, "my_schema") == "[my_db].[my_schema]" + assert ( + sm.get_schema_perm("my_db", "my_catalog", "my_schema") + == "[my_db].[my_catalog].[my_schema]" + ) + assert sm.get_schema_perm("my_db", None, None) is None + assert sm.get_schema_perm("my_db", "my_catalog", None) is None diff --git a/tests/unit_tests/utils/filters_test.py b/tests/unit_tests/utils/filters_test.py new file mode 100644 index 000000000..b41774bb2 --- /dev/null +++ b/tests/unit_tests/utils/filters_test.py @@ -0,0 +1,54 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from pytest_mock import MockerFixture +from sqlalchemy import create_engine + +from superset.utils.filters import get_dataset_access_filters + + +def test_get_dataset_access_filters(mocker: MockerFixture) -> None: + """ + Test the `get_dataset_access_filters` function. + """ + from superset.connectors.sqla.models import SqlaTable + from superset.extensions import security_manager + + mocker.patch.object( + security_manager, + "get_accessible_databases", + return_value=[1, 3], + ) + mocker.patch.object( + security_manager, + "user_view_menu_names", + side_effect=[ + {"[db].[catalog1].[schema1].[table1](id:1)"}, + {"[db].[catalog1].[schema2]"}, + {"[db].[catalog2]"}, + ], + ) + + clause = get_dataset_access_filters(SqlaTable) + engine = create_engine("sqlite://") + compiled_query = clause.compile(engine, compile_kwargs={"literal_binds": True}) + assert str(compiled_query) == ( + "dbs.id IN (1, 3) " + "OR tables.perm IN ('[db].[catalog1].[schema1].[table1](id:1)') " + "OR tables.catalog_perm IN ('[db].[catalog2]') OR " + "tables.schema_perm IN ('[db].[catalog1].[schema2]')" + ) diff --git a/tests/unit_tests/utils/test_core.py b/tests/unit_tests/utils/test_core.py index a8081e0f3..2ebec87c2 100644 --- a/tests/unit_tests/utils/test_core.py +++ b/tests/unit_tests/utils/test_core.py @@ -29,6 
+29,7 @@ from superset.utils.core import ( DateColumn, generic_find_constraint_name, generic_find_fk_constraint_name, + get_datasource_full_name, is_test, normalize_dttm_col, parse_boolean_string, @@ -369,3 +370,29 @@ def test_generic_find_fk_constraint_none_exist(): ) assert result is None + + +def test_get_datasource_full_name(): + """ + Test the `get_datasource_full_name` function. + + This is used to build permissions, so it doesn't really return the datasource full + name. Instead, it returns a fully qualified table name that includes the database + name and schema, with each part wrapped in square brackets. + """ + assert ( + get_datasource_full_name("db", "table", "catalog", "schema") + == "[db].[catalog].[schema].[table]" + ) + + assert get_datasource_full_name("db", "table", None, None) == "[db].[table]" + + assert ( + get_datasource_full_name("db", "table", None, "schema") + == "[db].[schema].[table]" + ) + + assert ( + get_datasource_full_name("db", "table", "catalog", None) + == "[db].[catalog].[table]" + ) diff --git a/tests/unit_tests/views/database/__init__.py b/tests/unit_tests/views/database/__init__.py new file mode 100644 index 000000000..13a83393a --- /dev/null +++ b/tests/unit_tests/views/database/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/unit_tests/views/database/mixins_test.py b/tests/unit_tests/views/database/mixins_test.py new file mode 100644 index 000000000..1752d9763 --- /dev/null +++ b/tests/unit_tests/views/database/mixins_test.py @@ -0,0 +1,65 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from pytest_mock import MockerFixture + +from superset.views.database.mixins import DatabaseMixin + + +def test_pre_add_update_with_catalog(mocker: MockerFixture) -> None: + """ + Test the `_pre_add_update` method on a DB with catalog support. 
+ """ + from superset.models.core import Database + + add_permission_view_menu = mocker.patch( + "superset.views.database.mixins.security_manager.add_permission_view_menu" + ) + + database = Database( + database_name="my_db", + id=42, + sqlalchemy_uri="postgresql://user:password@host:5432/examples", + ) + mocker.patch.object( + database, + "get_all_catalog_names", + return_value=["examples", "other"], + ) + mocker.patch.object( + database, + "get_all_schema_names", + side_effect=[ + ["public", "information_schema"], + ["secret"], + ], + ) + + mixin = DatabaseMixin() + mixin._pre_add_update(database) + + add_permission_view_menu.assert_has_calls( + [ + mocker.call("database_access", "[my_db].(id:42)"), + mocker.call("catalog_access", "[my_db].[examples]"), + mocker.call("catalog_access", "[my_db].[other]"), + mocker.call("schema_access", "[my_db].[examples].[public]"), + mocker.call("schema_access", "[my_db].[examples].[information_schema]"), + mocker.call("schema_access", "[my_db].[other].[secret]"), + ], + any_order=True, + )