feat(SIP-95): permissions for catalogs (#28317)

This commit is contained in:
Beto Dealmeida 2024-05-06 11:41:58 -04:00 committed by GitHub
parent 9a339f08a7
commit e90246fd1f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
50 changed files with 2381 additions and 316 deletions

View File

@ -97,12 +97,37 @@ class CreateDatabaseCommand(BaseCommand):
db.session.commit()
# adding a new database we always want to force refresh schema list
schemas = database.get_all_schema_names(cache=False, ssh_tunnel=ssh_tunnel)
for schema in schemas:
security_manager.add_permission_view_menu(
"schema_access", security_manager.get_schema_perm(database, schema)
# add catalog/schema permissions
if database.db_engine_spec.supports_catalog:
catalogs = database.get_all_catalog_names(
cache=False,
ssh_tunnel=ssh_tunnel,
)
for catalog in catalogs:
security_manager.add_permission_view_menu(
"catalog_access",
security_manager.get_catalog_perm(
database.database_name, catalog
),
)
else:
# add a dummy catalog for DBs that don't support them
catalogs = [None]
for catalog in catalogs:
for schema in database.get_all_schema_names(
catalog=catalog,
cache=False,
ssh_tunnel=ssh_tunnel,
):
security_manager.add_permission_view_menu(
"schema_access",
security_manager.get_schema_perm(
database.database_name,
catalog,
schema,
),
)
except (
SSHTunnelInvalidError,

View File

@ -14,6 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import logging
from typing import Any, cast
@ -29,7 +32,6 @@ from superset.daos.database import DatabaseDAO
from superset.exceptions import SupersetException
from superset.extensions import db, security_manager
from superset.models.core import Database
from superset.utils.core import DatasourceName
logger = logging.getLogger(__name__)
@ -37,8 +39,15 @@ logger = logging.getLogger(__name__)
class TablesDatabaseCommand(BaseCommand):
_model: Database
def __init__(self, db_id: int, schema_name: str, force: bool):
def __init__(
self,
db_id: int,
catalog_name: str | None,
schema_name: str,
force: bool,
):
self._db_id = db_id
self._catalog_name = catalog_name
self._schema_name = schema_name
self._force = force
@ -47,11 +56,11 @@ class TablesDatabaseCommand(BaseCommand):
try:
tables = security_manager.get_datasources_accessible_by_user(
database=self._model,
catalog=self._catalog_name,
schema=self._schema_name,
datasource_names=sorted(
DatasourceName(*datasource_name)
for datasource_name in self._model.get_all_table_names_in_schema(
catalog=None,
self._model.get_all_table_names_in_schema(
catalog=self._catalog_name,
schema=self._schema_name,
force=self._force,
cache=self._model.table_cache_enabled,
@ -62,11 +71,11 @@ class TablesDatabaseCommand(BaseCommand):
views = security_manager.get_datasources_accessible_by_user(
database=self._model,
catalog=self._catalog_name,
schema=self._schema_name,
datasource_names=sorted(
DatasourceName(*datasource_name)
for datasource_name in self._model.get_all_view_names_in_schema(
catalog=None,
self._model.get_all_view_names_in_schema(
catalog=self._catalog_name,
schema=self._schema_name,
force=self._force,
cache=self._model.table_cache_enabled,
@ -81,11 +90,15 @@ class TablesDatabaseCommand(BaseCommand):
db.session.query(SqlaTable)
.filter(
SqlaTable.database_id == self._model.id,
SqlaTable.catalog == self._catalog_name,
SqlaTable.schema == self._schema_name,
)
.options(
load_only(
SqlaTable.schema, SqlaTable.table_name, SqlaTable.extra
SqlaTable.catalog,
SqlaTable.schema,
SqlaTable.table_name,
SqlaTable.extra,
),
lazyload(SqlaTable.columns),
lazyload(SqlaTable.metrics),

View File

@ -18,10 +18,9 @@
from __future__ import annotations
import logging
from typing import Any, Optional
from typing import Any
from flask_appbuilder.models.sqla import Model
from marshmallow import ValidationError
from superset import is_feature_enabled, security_manager
from superset.commands.base import BaseCommand
@ -50,12 +49,12 @@ logger = logging.getLogger(__name__)
class UpdateDatabaseCommand(BaseCommand):
_model: Optional[Database]
_model: Database | None
def __init__(self, model_id: int, data: dict[str, Any]):
self._properties = data.copy()
self._model_id = model_id
self._model: Optional[Database] = None
self._model: Database | None = None
def run(self) -> Model:
self._model = DatabaseDAO.find_by_id(self._model_id)
@ -85,7 +84,7 @@ class UpdateDatabaseCommand(BaseCommand):
)
database.set_sqlalchemy_uri(database.sqlalchemy_uri)
ssh_tunnel = self._handle_ssh_tunnel(database)
self._refresh_schemas(database, original_database_name, ssh_tunnel)
self._refresh_catalogs(database, original_database_name, ssh_tunnel)
except SSHTunnelError as ex:
# allow exception to bubble up for debugging information
raise ex
@ -121,67 +120,200 @@ class UpdateDatabaseCommand(BaseCommand):
ssh_tunnel_properties,
).run()
def _refresh_schemas(
def _get_catalog_names(
self,
database: Database,
original_database_name: str,
ssh_tunnel: Optional[SSHTunnel],
) -> None:
ssh_tunnel: SSHTunnel | None,
) -> set[str]:
"""
Add permissions for any new schemas.
Helper method to load catalogs.
This method captures a generic exception, since errors could potentially come
from any of the 50+ database drivers we support.
"""
try:
schemas = database.get_all_schema_names(ssh_tunnel=ssh_tunnel)
return database.get_all_catalog_names(
force=True,
ssh_tunnel=ssh_tunnel,
)
except Exception as ex:
db.session.rollback()
raise DatabaseConnectionFailedError() from ex
for schema in schemas:
original_vm = security_manager.get_schema_perm(
def _get_schema_names(
self,
database: Database,
catalog: str | None,
ssh_tunnel: SSHTunnel | None,
) -> set[str]:
"""
Helper method to load schemas.
This method captures a generic exception, since errors could potentially come
from any of the 50+ database drivers we support.
"""
try:
return database.get_all_schema_names(
force=True,
catalog=catalog,
ssh_tunnel=ssh_tunnel,
)
except Exception as ex:
db.session.rollback()
raise DatabaseConnectionFailedError() from ex
def _refresh_catalogs(
self,
database: Database,
original_database_name: str,
ssh_tunnel: SSHTunnel | None,
) -> None:
"""
Add permissions for any new catalogs and schemas.
"""
catalogs = (
self._get_catalog_names(database, ssh_tunnel)
if database.db_engine_spec.supports_catalog
else [None]
)
for catalog in catalogs:
schemas = self._get_schema_names(database, catalog, ssh_tunnel)
if catalog:
perm = security_manager.get_catalog_perm(
original_database_name,
catalog,
)
existing_pvm = security_manager.find_permission_view_menu(
"catalog_access",
perm,
)
if not existing_pvm:
# new catalog
security_manager.add_permission_view_menu(
"catalog_access",
security_manager.get_catalog_perm(
database.database_name,
catalog,
),
)
for schema in schemas:
security_manager.add_permission_view_menu(
"schema_access",
security_manager.get_schema_perm(
database.database_name,
catalog,
schema,
),
)
continue
# add possible new schemas in catalog
self._refresh_schemas(
database,
original_database_name,
catalog,
schemas,
)
if original_database_name != database.database_name:
self._rename_database_in_permissions(
database,
original_database_name,
catalog,
schemas,
)
db.session.commit()
def _refresh_schemas(
self,
database: Database,
original_database_name: str,
catalog: str | None,
schemas: set[str],
) -> None:
"""
Add new schemas that don't have permissions yet.
"""
for schema in schemas:
perm = security_manager.get_schema_perm(
original_database_name,
catalog,
schema,
)
existing_pvm = security_manager.find_permission_view_menu(
"schema_access",
original_vm,
perm,
)
if not existing_pvm:
# new schema
security_manager.add_permission_view_menu(
"schema_access",
security_manager.get_schema_perm(database.database_name, schema),
new_name = security_manager.get_schema_perm(
database.database_name,
catalog,
schema,
)
continue
security_manager.add_permission_view_menu("schema_access", new_name)
if original_database_name == database.database_name:
continue
def _rename_database_in_permissions(
self,
database: Database,
original_database_name: str,
catalog: str | None,
schemas: set[str],
) -> None:
new_name = security_manager.get_catalog_perm(
database.database_name,
catalog,
)
# rename existing schema permission
existing_pvm.view_menu.name = security_manager.get_schema_perm(
# rename existing catalog permission
if catalog:
perm = security_manager.get_catalog_perm(
original_database_name,
catalog,
)
existing_pvm = security_manager.find_permission_view_menu(
"catalog_access",
perm,
)
if existing_pvm:
existing_pvm.view_menu.name = new_name
for schema in schemas:
new_name = security_manager.get_schema_perm(
database.database_name,
catalog,
schema,
)
# rename existing schema permission
perm = security_manager.get_schema_perm(
original_database_name,
catalog,
schema,
)
existing_pvm = security_manager.find_permission_view_menu(
"schema_access",
perm,
)
if existing_pvm:
existing_pvm.view_menu.name = new_name
# rename permissions on datasets and charts
for dataset in DatabaseDAO.get_datasets(
database.id,
catalog=None,
catalog=catalog,
schema=schema,
):
dataset.schema_perm = existing_pvm.view_menu.name
dataset.schema_perm = new_name
for chart in DatasetDAO.get_related_objects(dataset.id)["charts"]:
chart.schema_perm = existing_pvm.view_menu.name
db.session.commit()
chart.schema_perm = new_name
def validate(self) -> None:
exceptions: list[ValidationError] = []
database_name: Optional[str] = self._properties.get("database_name")
if database_name:
# Check database_name uniqueness
if database_name := self._properties.get("database_name"):
if not DatabaseDAO.validate_update_uniqueness(
self._model_id, database_name
self._model_id,
database_name,
):
exceptions.append(DatabaseExistsValidationError())
if exceptions:
raise DatabaseInvalidError(exceptions=exceptions)
raise DatabaseInvalidError(exceptions=[DatabaseExistsValidationError()])

View File

@ -126,7 +126,11 @@ class SqlResultExportCommand(BaseCommand):
}:
# remove extra row from `increased_limit`
limit -= 1
df = self._query.database.get_df(sql, self._query.schema)[:limit]
df = self._query.database.get_df(
sql,
self._query.catalog,
self._query.schema,
)[:limit]
csv_data = csv.df_to_escaped_csv(df, index=False, **config["CSV_EXPORT"])

View File

@ -564,9 +564,9 @@ IS_FEATURE_ENABLED_FUNC: Callable[[str, bool | None], bool] | None = None
#
# Takes as a parameter the common bootstrap payload before transformations.
# Returns a dict containing data that should be added or overridden to the payload.
COMMON_BOOTSTRAP_OVERRIDES_FUNC: Callable[[dict[str, Any]], dict[str, Any]] = ( # noqa: E731
lambda data: {}
) # default: empty dict
COMMON_BOOTSTRAP_OVERRIDES_FUNC: Callable[ # noqa: E731
[dict[str, Any]], dict[str, Any]
] = lambda data: {}
# EXTRA_CATEGORICAL_COLOR_SCHEMES is used for adding custom categorical color schemes
# example code for "My custom warm to hot" color scheme

View File

@ -211,6 +211,7 @@ class BaseDatasource(AuditMixinNullable, ImportExportMixin): # pylint: disable=
params = Column(String(1000))
perm = Column(String(1000))
schema_perm = Column(String(1000))
catalog_perm = Column(String(1000), nullable=True, default=None)
is_managed_externally = Column(Boolean, nullable=False, default=False)
external_url = Column(Text, nullable=True)
@ -1261,9 +1262,20 @@ class SqlaTable(
anchor = f'<a target="_blank" href="{self.explore_url}">{name}</a>'
return Markup(anchor)
def get_catalog_perm(self) -> str | None:
    """
    Return the ``catalog_access`` permission name for this table's catalog.

    NOTE(review): the previous docstring ("catalog permission if present,
    database one otherwise") described a fallback this method does not
    implement — it simply delegates to the security manager with the
    table's database name and catalog.
    """
    return security_manager.get_catalog_perm(
        self.database.database_name,
        self.catalog,
    )
def get_schema_perm(self) -> str | None:
"""Returns schema permission if present, database one otherwise."""
return security_manager.get_schema_perm(self.database, self.schema or None)
return security_manager.get_schema_perm(
self.database.database_name,
self.catalog,
self.schema or None,
)
def get_perm(self) -> str:
"""
@ -1282,7 +1294,10 @@ class SqlaTable(
@property
def full_name(self) -> str:
return utils.get_datasource_full_name(
self.database, self.table_name, schema=self.schema
self.database,
self.table_name,
catalog=self.catalog,
schema=self.schema,
)
@property
@ -1736,7 +1751,10 @@ class SqlaTable(
try:
df = self.database.get_df(
sql, self.schema or None, mutator=assign_column_label
sql,
None,
self.schema or None,
mutator=assign_column_label,
)
except (SupersetErrorException, SupersetErrorsException) as ex:
# SupersetError(s) exception should not be captured; instead, they should
@ -1870,34 +1888,45 @@ class SqlaTable(
cls,
database: Database,
datasource_name: str,
catalog: str | None = None,
schema: str | None = None,
) -> list[SqlaTable]:
query = (
db.session.query(cls)
.filter_by(database_id=database.id)
.filter_by(table_name=datasource_name)
)
filters = {
"database_id": database.id,
"table_name": datasource_name,
}
if catalog:
filters["catalog"] = catalog
if schema:
query = query.filter_by(schema=schema)
return query.all()
filters["schema"] = schema
return db.session.query(cls).filter_by(**filters).all()
@classmethod
def query_datasources_by_permissions( # pylint: disable=invalid-name
cls,
database: Database,
permissions: set[str],
catalog_perms: set[str],
schema_perms: set[str],
) -> list[SqlaTable]:
# TODO(hughhhh): add unit test
# remove empty sets from the query, since SQLAlchemy produces horrible SQL for
# Model.column._in({}):
#
# table.column IN (SELECT 1 FROM (SELECT 1) WHERE 1!=1)
filters = [
method.in_(perms)
for method, perms in zip(
(SqlaTable.perm, SqlaTable.schema_perm, SqlaTable.catalog_perm),
(permissions, schema_perms, catalog_perms),
)
if perms
]
return (
db.session.query(cls)
.filter_by(database_id=database.id)
.filter(
or_(
SqlaTable.perm.in_(permissions),
SqlaTable.schema_perm.in_(schema_perms),
)
)
.filter(or_(*filters))
.all()
)

View File

@ -132,6 +132,7 @@ MODEL_API_RW_METHOD_PERMISSION_MAP = {
"related_objects": "read",
"tables": "read",
"schemas": "read",
"catalogs": "read",
"select_star": "read",
"table_metadata": "read",
"table_metadata_deprecated": "read",

View File

@ -73,10 +73,12 @@ from superset.daos.database import DatabaseDAO, DatabaseUserOAuth2TokensDAO
from superset.databases.decorators import check_table_access
from superset.databases.filters import DatabaseFilter, DatabaseUploadEnabledFilter
from superset.databases.schemas import (
CatalogsResponseSchema,
ColumnarMetadataUploadFilePostSchema,
ColumnarUploadPostSchema,
CSVMetadataUploadFilePostSchema,
CSVUploadPostSchema,
database_catalogs_query_schema,
database_schemas_query_schema,
database_tables_query_schema,
DatabaseConnectionSchema,
@ -146,6 +148,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
"table_extra_metadata",
"table_extra_metadata_deprecated",
"select_star",
"catalogs",
"schemas",
"test_connection",
"related_objects",
@ -266,6 +269,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
edit_model_schema = DatabasePutSchema()
apispec_parameter_schemas = {
"database_catalogs_query_schema": database_catalogs_query_schema,
"database_schemas_query_schema": database_schemas_query_schema,
"database_tables_query_schema": database_tables_query_schema,
"get_export_ids_schema": get_export_ids_schema,
@ -273,6 +277,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
openapi_spec_tag = "Database"
openapi_spec_component_schemas = (
CatalogsResponseSchema,
ColumnarUploadPostSchema,
CSVUploadPostSchema,
DatabaseConnectionSchema,
@ -604,6 +609,70 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
)
return self.response_422(message=str(ex))
@expose("/<int:pk>/catalogs/")
@protect()
@safe
@rison(database_catalogs_query_schema)
@statsd_metrics
@event_logger.log_this_with_context(
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}" f".catalogs",
log_to_statsd=False,
)
def catalogs(self, pk: int, **kwargs: Any) -> FlaskResponse:
"""Get all catalogs from a database.
---
get:
summary: Get all catalogs from a database
parameters:
- in: path
schema:
type: integer
name: pk
description: The database id
- in: query
name: q
content:
application/json:
schema:
$ref: '#/components/schemas/database_catalogs_query_schema'
responses:
200:
description: A List of all catalogs from the database
content:
application/json:
schema:
$ref: "#/components/schemas/CatalogsResponseSchema"
400:
$ref: '#/components/responses/400'
401:
$ref: '#/components/responses/401'
404:
$ref: '#/components/responses/404'
500:
$ref: '#/components/responses/500'
"""
database = DatabaseDAO.find_by_id(pk)
if not database:
return self.response_404()
try:
catalogs = database.get_all_catalog_names(
cache=database.catalog_cache_enabled,
cache_timeout=database.catalog_cache_timeout or None,
force=kwargs["rison"].get("force", False),
)
catalogs = security_manager.get_catalogs_accessible_by_user(
database,
catalogs,
)
return self.response(200, result=list(catalogs))
except OperationalError:
return self.response(
500,
message="There was an error connecting to the database",
)
except SupersetException as ex:
return self.response(ex.status, message=ex.message)
@expose("/<int:pk>/schemas/")
@protect()
@safe
@ -650,13 +719,19 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
if not database:
return self.response_404()
try:
catalog = kwargs["rison"].get("catalog")
schemas = database.get_all_schema_names(
catalog=catalog,
cache=database.schema_cache_enabled,
cache_timeout=database.schema_cache_timeout or None,
force=kwargs["rison"].get("force", False),
)
schemas = security_manager.get_schemas_accessible_by_user(database, schemas)
return self.response(200, result=schemas)
schemas = security_manager.get_schemas_accessible_by_user(
database,
catalog,
schemas,
)
return self.response(200, result=list(schemas))
except OperationalError:
return self.response(
500, message="There was an error connecting to the database"
@ -718,10 +793,11 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
$ref: '#/components/responses/500'
"""
force = kwargs["rison"].get("force", False)
catalog_name = kwargs["rison"].get("catalog_name")
schema_name = kwargs["rison"].get("schema_name", "")
try:
command = TablesDatabaseCommand(pk, schema_name, force)
command = TablesDatabaseCommand(pk, catalog_name, schema_name, force)
payload = command.run()
return self.response(200, **payload)
except DatabaseNotFoundError:

View File

@ -28,11 +28,12 @@ from superset.models.core import Database
from superset.views.base import BaseFilter
def can_access_databases(
view_menu_name: str,
) -> set[str]:
def can_access_databases(view_menu_name: str) -> set[str]:
"""
Return names of databases available in `view_menu_name`.
"""
return {
security_manager.unpack_database_and_schema(vm).database
vm.split(".")[0][1:-1]
for vm in security_manager.user_view_menu_names(view_menu_name)
}
@ -56,17 +57,21 @@ class DatabaseFilter(BaseFilter): # pylint: disable=too-few-public-methods
# We can proceed with default filtering now
if security_manager.can_access_all_databases():
return query
database_perms = security_manager.user_view_menu_names("database_access")
schema_access_databases = can_access_databases("schema_access")
database_perms = security_manager.user_view_menu_names("database_access")
catalog_access_databases = can_access_databases("catalog_access")
schema_access_databases = can_access_databases("schema_access")
datasource_access_databases = can_access_databases("datasource_access")
database_names = sorted(
catalog_access_databases
| schema_access_databases
| datasource_access_databases
)
return query.filter(
or_(
self.model.perm.in_(database_perms),
self.model.database_name.in_(
[*schema_access_databases, *datasource_access_databases]
),
self.model.database_name.in_(database_names),
)
)

View File

@ -56,6 +56,14 @@ from superset.security.analytics_db_safety import check_sqlalchemy_uri
from superset.utils.core import markdown, parse_ssl_cert
database_schemas_query_schema = {
"type": "object",
"properties": {
"force": {"type": "boolean"},
"catalog": {"type": "string"},
},
}
# rison query schema for the `q` parameter of GET /database/{pk}/catalogs/;
# `force` bypasses the catalog metadata cache
database_catalogs_query_schema = {
    "type": "object",
    "properties": {"force": {"type": "boolean"}},
}
@ -65,6 +73,7 @@ database_tables_query_schema = {
"properties": {
"force": {"type": "boolean"},
"schema_name": {"type": "string"},
"catalog_name": {"type": "string"},
},
"required": ["schema_name"],
}
@ -712,6 +721,12 @@ class SchemasResponseSchema(Schema):
)
class CatalogsResponseSchema(Schema):
    """Response payload for the ``GET /database/{pk}/catalogs/`` endpoint."""

    result = fields.List(
        fields.String(metadata={"description": "A database catalog name"})
    )
class DatabaseTablesResponse(Schema):
extra = fields.Dict(
metadata={"description": "Extra data used to specify column metadata"}

View File

@ -404,6 +404,10 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
# Does the DB support catalogs? A catalog here is a group of schemas, and has
# different names depending on the DB: BigQuery calls it a "project", Postgres calls
# it a "database", Trino calls it a "catalog", etc.
#
# When this is changed to true in a DB engine spec it MUST support the
# `get_default_catalog` and `get_catalog_names` methods. In addition, you MUST write
# a database migration updating any existing schema permissions.
supports_catalog = False
# Can the catalog be changed on a per-query basis?
@ -638,10 +642,20 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return driver in cls.drivers
@classmethod
def get_default_catalog(
    cls,
    database: Database,  # pylint: disable=unused-argument
) -> str | None:
    """
    Return the default catalog for a given database.

    The base implementation has no notion of catalogs and always returns
    ``None``; engine specs that set ``supports_catalog`` must override
    this (see the note on ``supports_catalog``).
    """
    return None
@classmethod
def get_default_schema(cls, database: Database, catalog: str | None) -> str | None:
    """
    Return the default schema for a catalog in a given database.

    Opens a SQLAlchemy inspector against ``catalog`` and reports the
    inspector's default schema name.
    """
    with database.get_inspector(catalog=catalog) as inspector:
        return inspector.default_schema_name
@ -1412,24 +1426,24 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
cls,
database: Database,
inspector: Inspector,
) -> list[str]:
) -> set[str]:
"""
Get all catalogs from database.
This needs to be implemented per database, since SQLAlchemy doesn't offer an
abstraction.
"""
return []
return set()
@classmethod
def get_schema_names(cls, inspector: Inspector) -> list[str]:
def get_schema_names(cls, inspector: Inspector) -> set[str]:
"""
Get all schemas from database
:param inspector: SqlAlchemy inspector
:return: All schemas in the database
"""
return sorted(inspector.get_schema_names())
return set(inspector.get_schema_names())
@classmethod
def get_table_names( # pylint: disable=unused-argument

View File

@ -127,7 +127,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
allows_hidden_cc_in_orderby = True
supports_catalog = True
supports_catalog = False
"""
https://www.python.org/dev/peps/pep-0249/#arraysize
@ -464,7 +464,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
cls,
database: Database,
inspector: Inspector,
) -> list[str]:
) -> set[str]:
"""
Get all catalogs.
@ -475,7 +475,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
client = cls._get_client(engine)
projects = client.list_projects()
return sorted(project.project_id for project in projects)
return {project.project_id for project in projects}
@classmethod
def get_allow_cost_estimate(cls, extra: dict[str, Any]) -> bool:

View File

@ -278,7 +278,7 @@ class ClickHouseConnectEngineSpec(BasicParametersMixin, ClickHouseEngineSpec):
@classmethod
def get_function_names(cls, database: Database) -> list[str]:
# pylint: disable=import-outside-toplevel,import-error
# pylint: disable=import-outside-toplevel, import-error
from clickhouse_connect.driver.exceptions import ClickHouseError
if cls._function_names:
@ -340,7 +340,7 @@ class ClickHouseConnectEngineSpec(BasicParametersMixin, ClickHouseEngineSpec):
def validate_parameters(
cls, properties: BasicPropertiesType
) -> list[SupersetError]:
# pylint: disable=import-outside-toplevel,import-error
# pylint: disable=import-outside-toplevel, import-error
from clickhouse_connect.driver import default_port
parameters = properties.get("parameters", {})

View File

@ -74,13 +74,12 @@ class ImpalaEngineSpec(BaseEngineSpec):
return None
@classmethod
def get_schema_names(cls, inspector: Inspector) -> list[str]:
schemas = [
def get_schema_names(cls, inspector: Inspector) -> set[str]:
return {
row[0]
for row in inspector.engine.execute("SHOW SCHEMAS")
if not row[0].startswith("_")
]
return schemas
}
@classmethod
def has_implicit_cancel(cls) -> bool:

View File

@ -101,8 +101,6 @@ class PostgresBaseEngineSpec(BaseEngineSpec):
engine = ""
engine_name = "PostgreSQL"
supports_catalog = True
_time_grain_expressions = {
None: "{col}",
TimeGrain.SECOND: "DATE_TRUNC('second', {col})",
@ -199,7 +197,10 @@ class PostgresBaseEngineSpec(BaseEngineSpec):
class PostgresEngineSpec(BasicParametersMixin, PostgresBaseEngineSpec):
engine = "postgresql"
engine_aliases = {"postgres"}
supports_dynamic_schema = True
supports_catalog = True
supports_dynamic_catalog = True
default_driver = "psycopg2"
sqlalchemy_uri_placeholder = (
@ -296,6 +297,29 @@ class PostgresEngineSpec(BasicParametersMixin, PostgresBaseEngineSpec):
return super().get_default_schema_for_query(database, query)
@classmethod
def adjust_engine_params(
    cls,
    uri: URL,
    connect_args: dict[str, Any],
    catalog: str | None = None,
    schema: str | None = None,
) -> tuple[URL, dict[str, Any]]:
    """
    Point the SQLAlchemy URL at the requested catalog.

    In Postgres the catalog is the "database" component of the URL, so a
    truthy ``catalog`` replaces that component; ``connect_args`` and
    ``schema`` are passed through untouched.
    """
    adjusted_uri = uri.set(database=catalog) if catalog else uri
    return adjusted_uri, connect_args
@classmethod
def get_default_catalog(cls, database: Database) -> str | None:
    """
    Return the default catalog for a given database.

    For Postgres this is the "database" component of the connection URL.
    """
    url = database.url_object
    return url.database
@classmethod
def get_prequeries(
cls,
@ -346,13 +370,13 @@ class PostgresEngineSpec(BasicParametersMixin, PostgresBaseEngineSpec):
cls,
database: Database,
inspector: Inspector,
) -> list[str]:
) -> set[str]:
"""
Return all catalogs.
In Postgres, a catalog is called a "database".
"""
return sorted(
return {
catalog
for (catalog,) in inspector.bind.execute(
"""
@ -360,7 +384,7 @@ SELECT datname FROM pg_database
WHERE datistemplate = false;
"""
)
)
}
@classmethod
def get_table_names(

View File

@ -648,6 +648,8 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
engine_name = "Presto"
allows_alias_to_source_column = False
supports_catalog = False
custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = {
COLUMN_DOES_NOT_EXIST_REGEX: (
__(
@ -815,11 +817,11 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
cls,
database: Database,
inspector: Inspector,
) -> list[str]:
) -> set[str]:
"""
Get all catalogs.
"""
return [catalog for (catalog,) in inspector.bind.execute("SHOW CATALOGS")]
return {catalog for (catalog,) in inspector.bind.execute("SHOW CATALOGS")}
@classmethod
def _create_column_info(

View File

@ -85,7 +85,7 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
sqlalchemy_uri_placeholder = "snowflake://"
supports_dynamic_schema = True
supports_catalog = True
supports_catalog = False
_time_grain_expressions = {
None: "{col}",
@ -174,18 +174,18 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
cls,
database: "Database",
inspector: Inspector,
) -> list[str]:
) -> set[str]:
"""
Return all catalogs.
In Snowflake, a catalog is called a "database".
"""
return sorted(
return {
catalog
for (catalog,) in inspector.bind.execute(
"SELECT DATABASE_NAME from information_schema.databases"
)
)
}
@classmethod
def epoch_to_dttm(cls) -> str:

View File

@ -270,11 +270,6 @@ class SupersetShillelaghAdapter(Adapter):
self.schema = parts.pop(-1) if parts else None
self.catalog = parts.pop(-1) if parts else None
if self.catalog:
# TODO (betodealmeida): when SIP-95 is implemented we should check to see if
# the database has multi-catalog enabled, and if so, give access.
raise NotImplementedError("Catalogs are not currently supported")
# If the table has a single integer primary key we use that as the row ID in order
# to perform updates and deletes. Otherwise we can only do inserts and selects.
self._rowid: str | None = None

View File

@ -0,0 +1,116 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import logging
from alembic import op
from superset import db, security_manager
from superset.daos.database import DatabaseDAO
from superset.models.core import Database
logger = logging.getLogger(__name__)
def upgrade_schema_perms(engine: str | None = None) -> None:
    """
    Update schema permissions to include the catalog part.

    Before SIP-95 schema permissions were stored in the format
    `[db].[schema]`. With the introduction of catalogs, any existing
    permissions need to be renamed to include the catalog:
    `[db].[catalog].[schema]`.

    :param engine: when given, only databases on this engine are migrated
    """
    bind = op.get_bind()
    session = db.Session(bind=bind)
    for database in session.query(Database).all():
        db_engine_spec = database.db_engine_spec
        # skip DBs without catalog support and, when filtering by engine,
        # DBs on a different engine
        if (
            engine and db_engine_spec.engine != engine
        ) or not db_engine_spec.supports_catalog:
            continue
        # the default catalog is what pre-SIP-95 schema perms implicitly
        # pointed at
        catalog = database.get_default_catalog()
        ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
        for schema in database.get_all_schema_names(
            catalog=catalog,
            cache=False,
            ssh_tunnel=ssh_tunnel,
        ):
            # look up the old-style perm: `[db].[schema]` (no catalog part)
            perm = security_manager.get_schema_perm(
                database.database_name,
                None,
                schema,
            )
            existing_pvm = security_manager.find_permission_view_menu(
                "schema_access",
                perm,
            )
            if existing_pvm:
                # rename in place to the `[db].[catalog].[schema]` form
                existing_pvm.view_menu.name = security_manager.get_schema_perm(
                    database.database_name,
                    catalog,
                    schema,
                )
    # NOTE(review): the rendered diff lost indentation; the commit is placed
    # at function level (one commit for all renames) — confirm against the
    # upstream migration before applying.
    session.commit()
def downgrade_schema_perms(engine: str | None = None) -> None:
    """
    Update schema permissions to not have the catalog part.

    Before SIP-95 schema permissions were stored in the format
    `[db].[schema]`. With the introduction of catalogs, any existing
    permissions need to be renamed to include the catalog:
    `[db].[catalog].[schema]`.

    This helper function reverts the process.

    :param engine: when given, only databases on this engine are migrated
    """
    bind = op.get_bind()
    session = db.Session(bind=bind)
    for database in session.query(Database).all():
        db_engine_spec = database.db_engine_spec
        # skip DBs without catalog support and, when filtering by engine,
        # DBs on a different engine
        if (
            engine and db_engine_spec.engine != engine
        ) or not db_engine_spec.supports_catalog:
            continue
        catalog = database.get_default_catalog()
        ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
        for schema in database.get_all_schema_names(
            catalog=catalog,
            cache=False,
            ssh_tunnel=ssh_tunnel,
        ):
            # look up the new-style perm: `[db].[catalog].[schema]`
            perm = security_manager.get_schema_perm(
                database.database_name,
                catalog,
                schema,
            )
            existing_pvm = security_manager.find_permission_view_menu(
                "schema_access",
                perm,
            )
            if existing_pvm:
                # rename back to the pre-SIP-95 `[db].[schema]` form
                existing_pvm.view_menu.name = security_manager.get_schema_perm(
                    database.database_name,
                    None,
                    schema,
                )
    # NOTE(review): indentation reconstructed from a flattened diff; commit
    # placed at function level — confirm against the upstream migration.
    session.commit()

View File

@ -0,0 +1,53 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Add catalog_perm to tables
Revision ID: 58d051681a3b
Revises: 4a33124c18ad
Create Date: 2024-05-01 10:52:31.458433
"""
import sqlalchemy as sa
from alembic import op
from superset.migrations.shared.catalogs import (
downgrade_schema_perms,
upgrade_schema_perms,
)
# revision identifiers, used by Alembic.
revision = "58d051681a3b"
down_revision = "4a33124c18ad"
def upgrade():
    """Add a nullable ``catalog_perm`` column to datasets and charts."""
    # both datasets ("tables") and charts ("slices") carry the new permission
    for table_name in ("tables", "slices"):
        op.add_column(
            table_name,
            sa.Column("catalog_perm", sa.String(length=1000), nullable=True),
        )
    # rewrite existing schema perms to include the catalog (Postgres only)
    upgrade_schema_perms(engine="postgresql")
def downgrade():
    """Drop ``catalog_perm`` and revert Postgres schema perms to `[db].[schema]`."""
    for table_name in ("slices", "tables"):
        op.drop_column(table_name, "catalog_perm")
    downgrade_schema_perms(engine="postgresql")

View File

@ -78,7 +78,7 @@ from superset.sql_parse import Table
from superset.superset_typing import OAuth2ClientConfig, ResultSetColumnType
from superset.utils import cache as cache_util, core as utils
from superset.utils.backports import StrEnum
from superset.utils.core import get_username
from superset.utils.core import DatasourceName, get_username
from superset.utils.oauth2 import get_oauth2_access_token
config = app.config
@ -313,6 +313,14 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
def metadata_cache_timeout(self) -> dict[str, Any]:
return self.get_extra().get("metadata_cache_timeout", {})
@property
def catalog_cache_enabled(self) -> bool:
return "catalog_cache_timeout" in self.metadata_cache_timeout
@property
def catalog_cache_timeout(self) -> int | None:
return self.metadata_cache_timeout.get("catalog_cache_timeout")
@property
def schema_cache_enabled(self) -> bool:
return "schema_cache_timeout" in self.metadata_cache_timeout
@ -549,6 +557,18 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
yield conn
def get_default_catalog(self) -> str | None:
    """
    Return the default configured catalog for the database.

    Delegates to the database's engine spec, which owns catalog support.

    :return: the default catalog name, or ``None``
    """
    return self.db_engine_spec.get_default_catalog(self)
def get_default_schema(self, catalog: str | None) -> str | None:
    """
    Return the default schema for the database.

    Delegates to the database's engine spec.

    :param catalog: catalog to inspect; ``None`` when the engine has no
        catalog support
    :return: the default schema name, or ``None``
    """
    return self.db_engine_spec.get_default_schema(self, catalog)
def get_default_schema_for_query(self, query: Query) -> str | None:
"""
Return the default schema for a given query.
@ -706,19 +726,17 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
key="db:{self.id}:schema:{schema}:table_list",
cache=cache_manager.cache,
)
def get_all_table_names_in_schema( # pylint: disable=unused-argument
def get_all_table_names_in_schema(
self,
catalog: str | None,
schema: str,
cache: bool = False,
cache_timeout: int | None = None,
force: bool = False,
) -> set[tuple[str, str]]:
) -> set[DatasourceName]:
"""Parameters need to be passed as keyword arguments.
For unused parameters, they are referenced in
cache_util.memoized_func decorator.
:param catalog: optional catalog name
:param schema: schema name
:param cache: whether cache is enabled for the function
:param cache_timeout: timeout in seconds for the cache
@ -728,7 +746,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
try:
with self.get_inspector(catalog=catalog, schema=schema) as inspector:
return {
(table, schema)
DatasourceName(table, schema, catalog)
for table in self.db_engine_spec.get_table_names(
database=self,
inspector=inspector,
@ -742,19 +760,17 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
key="db:{self.id}:schema:{schema}:view_list",
cache=cache_manager.cache,
)
def get_all_view_names_in_schema( # pylint: disable=unused-argument
def get_all_view_names_in_schema(
self,
catalog: str | None,
schema: str,
cache: bool = False,
cache_timeout: int | None = None,
force: bool = False,
) -> set[tuple[str, str]]:
) -> set[DatasourceName]:
"""Parameters need to be passed as keyword arguments.
For unused parameters, they are referenced in
cache_util.memoized_func decorator.
:param catalog: optional catalog name
:param schema: schema name
:param cache: whether cache is enabled for the function
:param cache_timeout: timeout in seconds for the cache
@ -764,7 +780,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
try:
with self.get_inspector(catalog=catalog, schema=schema) as inspector:
return {
(view, schema)
DatasourceName(view, schema, catalog)
for view in self.db_engine_spec.get_view_names(
database=self,
inspector=inspector,
@ -792,22 +808,17 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
key="db:{self.id}:schema_list",
cache=cache_manager.cache,
)
def get_all_schema_names( # pylint: disable=unused-argument
def get_all_schema_names(
self,
*,
catalog: str | None = None,
cache: bool = False,
cache_timeout: int | None = None,
force: bool = False,
ssh_tunnel: SSHTunnel | None = None,
) -> list[str]:
"""Parameters need to be passed as keyword arguments.
) -> set[str]:
"""
Return the schemas in a given database
For unused parameters, they are referenced in
cache_util.memoized_func decorator.
:param cache: whether cache is enabled for the function
:param cache_timeout: timeout in seconds for the cache
:param force: whether to force refresh the cache
:param catalog: override default catalog
:param ssh_tunnel: SSH tunnel information needed to establish a connection
:return: schema list
"""
try:
@ -819,6 +830,27 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
except Exception as ex:
raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex
@cache_util.memoized_func(
    key="db:{self.id}:catalog_list",
    cache=cache_manager.cache,
)
def get_all_catalog_names(
    self,
    *,
    ssh_tunnel: SSHTunnel | None = None,
) -> set[str]:
    """
    Return the catalogs in a given database.

    The result is memoized under ``db:{self.id}:catalog_list``; callers may
    also pass ``cache``/``cache_timeout``/``force`` keyword arguments, which
    are consumed by the memoization decorator.

    :param ssh_tunnel: SSH tunnel information needed to establish a connection
    :return: set of catalog names
    """
    try:
        with self.get_inspector(ssh_tunnel=ssh_tunnel) as inspector:
            return self.db_engine_spec.get_catalog_names(self, inspector)
    except Exception as ex:
        # surface driver errors as the engine spec's mapped exception type
        raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex
@property
def db_engine_spec(self) -> builtins.type[db_engine_specs.BaseEngineSpec]:
url = make_url_safe(self.sqlalchemy_uri_decrypted)

View File

@ -209,7 +209,10 @@ class ImportExportMixin:
"""Get a mapping of foreign name to the local name of foreign keys"""
parent_rel = cls.__mapper__.relationships.get(cls.export_parent)
if parent_rel:
return {l.name: r.name for (l, r) in parent_rel.local_remote_pairs} # noqa: E741
return {
local.name: remote.name
for (local, remote) in parent_rel.local_remote_pairs
}
return {}
@classmethod
@ -772,6 +775,10 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
def database(self) -> "Database":
raise NotImplementedError()
@property
def catalog(self) -> str:
raise NotImplementedError()
@property
def schema(self) -> str:
raise NotImplementedError()
@ -1025,7 +1032,12 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
return df
try:
df = self.database.get_df(sql, self.schema, mutator=assign_column_label)
df = self.database.get_df(
sql,
self.catalog,
self.schema,
mutator=assign_column_label,
)
except Exception as ex: # pylint: disable=broad-except
df = pd.DataFrame()
status = QueryStatus.FAILED

View File

@ -84,6 +84,7 @@ class Slice( # pylint: disable=too-many-public-methods
cache_timeout = Column(Integer)
perm = Column(String(1000))
schema_perm = Column(String(1000))
catalog_perm = Column(String(1000), nullable=True, default=None)
# the last time a user has saved the chart, changed_on is referencing
# when the database row was last written
last_saved_at = Column(DateTime, nullable=True)

View File

@ -22,7 +22,7 @@ import logging
import re
import time
from collections import defaultdict
from typing import Any, Callable, cast, NamedTuple, Optional, TYPE_CHECKING, Union
from typing import Any, Callable, cast, NamedTuple, Optional, TYPE_CHECKING
from flask import current_app, Flask, g, Request
from flask_appbuilder import Model
@ -67,7 +67,7 @@ from superset.security.guest_token import (
GuestTokenUser,
GuestUser,
)
from superset.sql_parse import extract_tables_from_jinja_sql
from superset.sql_parse import extract_tables_from_jinja_sql, Table
from superset.superset_typing import Metric
from superset.utils.core import (
DatasourceName,
@ -89,7 +89,6 @@ if TYPE_CHECKING:
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.models.sql_lab import Query
from superset.sql_parse import Table
from superset.viz import BaseViz
logger = logging.getLogger(__name__)
@ -97,8 +96,9 @@ logger = logging.getLogger(__name__)
DATABASE_PERM_REGEX = re.compile(r"^\[.+\]\.\(id\:(?P<id>\d+)\)$")
class DatabaseAndSchema(NamedTuple):
class DatabaseCatalogSchema(NamedTuple):
database: str
catalog: Optional[str]
schema: str
@ -346,17 +346,51 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
return self.get_guest_user_from_request(request)
return None
def get_catalog_perm(
    self,
    database: str,
    catalog: Optional[str] = None,
) -> Optional[str]:
    """
    Return the database specific catalog permission.

    Catalog permissions have the format `[database].[catalog]`.

    :param database: The database name
    :param catalog: The database catalog name
    :return: The database specific catalog permission, or ``None`` when no
        catalog is given
    """
    if catalog is None:
        return None
    return f"[{database}].[{catalog}]"
def get_schema_perm(
self, database: Union["Database", str], schema: Optional[str] = None
self,
database: str,
catalog: Optional[str] = None,
schema: Optional[str] = None,
) -> Optional[str]:
"""
Return the database specific schema permission.
:param database: The Superset database or database name
:param schema: The Superset schema name
Catalogs were added in SIP-95, and not all databases support them. Because of
this, the format used for permissions is different depending on whether a
catalog is passed or not:
[database].[schema]
[database].[catalog].[schema]
:param database: The database name
:param catalog: The database catalog name
:param schema: The database schema name
:return: The database specific schema permission
"""
return f"[{database}].[{schema}]" if schema else None
if schema is None:
return None
if catalog:
return f"[{database}].[{catalog}].[{schema}]"
return f"[{database}].[{schema}]"
@staticmethod
def get_database_perm(database_id: int, database_name: str) -> Optional[str]:
@ -370,13 +404,6 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
) -> Optional[str]:
return f"[{database_name}].[{dataset_name}](id:{dataset_id})"
def unpack_database_and_schema(self, schema_permission: str) -> DatabaseAndSchema:
# [database_name].[schema|table]
schema_name = schema_permission.split(".")[1][1:-1]
database_name = schema_permission.split(".")[0][1:-1]
return DatabaseAndSchema(database_name, schema_name)
def can_access(self, permission_name: str, view_name: str) -> bool:
"""
Return True if the user can access the FAB permission/view, False otherwise.
@ -436,6 +463,17 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
or self.can_access("database_access", database.perm) # type: ignore
)
def can_access_catalog(self, database: "Database", catalog: str) -> bool:
    """
    Return True if the user can access the specified catalog.

    Access is hierarchical: all-datasource access or access to the whole
    database grants every catalog in it; otherwise the catalog-specific
    ``catalog_access`` permission is checked.
    """
    perm = self.get_catalog_perm(database.database_name, catalog)
    if self.can_access_all_datasources():
        return True
    if self.can_access_database(database):
        return True
    return bool(perm and self.can_access("catalog_access", perm))
def can_access_schema(self, datasource: "BaseDatasource") -> bool:
"""
Return True if the user can access the schema associated with specified
@ -448,6 +486,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
return (
self.can_access_all_datasources()
or self.can_access_database(datasource.database)
or self.can_access_catalog(datasource.database, datasource.catalog)
or self.can_access("schema_access", datasource.schema_perm or "")
)
@ -706,55 +745,150 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
]
def get_schemas_accessible_by_user(
self, database: "Database", schemas: list[str], hierarchical: bool = True
) -> list[str]:
self,
database: "Database",
catalog: Optional[str],
schemas: set[str],
hierarchical: bool = True,
) -> set[str]:
"""
Return the list of SQL schemas accessible by the user.
Return a filtered list of the schemas accessible by the user.
If no catalog is specified, the default catalog is used.
:param database: The SQL database
:param schemas: The list of eligible SQL schemas
:param catalog: An optional database catalog
:param schemas: A set of candidate schemas
:param hierarchical: Whether to check using the hierarchical permission logic
:returns: The list of accessible SQL schemas
:returns: The set of accessible database schemas
"""
# pylint: disable=import-outside-toplevel
from superset.connectors.sqla.models import SqlaTable
if hierarchical and self.can_access_database(database):
if hierarchical and (
self.can_access_database(database)
or (catalog and self.can_access_catalog(database, catalog))
):
return schemas
# schema_access
accessible_schemas = {
self.unpack_database_and_schema(s).schema
for s in self.user_view_menu_names("schema_access")
if s.startswith(f"[{database}].")
}
accessible_schemas: set[str] = set()
schema_access = self.user_view_menu_names("schema_access")
default_catalog = database.get_default_catalog()
default_schema = database.get_default_schema(default_catalog)
for perm in schema_access:
parts = [part[1:-1] for part in perm.split(".")]
if parts[0] != database.database_name:
continue
# [database].[schema] matches when no catalog is specified, or when the user
# specifies the default catalog
if len(parts) == 2 and (catalog is None or catalog == default_catalog):
accessible_schemas.add(parts[1])
# [database].[catalog].[schema] matches when the catalog is equal to the
# requested catalog or, when no catalog specified, it's equal to the default
# catalog.
elif len(parts) == 3 and parts[1] == (catalog or default_catalog):
accessible_schemas.add(parts[2])
# datasource_access
if perms := self.user_view_menu_names("datasource_access"):
tables = (
self.get_session.query(SqlaTable.schema)
.filter(SqlaTable.database_id == database.id)
.filter(SqlaTable.schema.isnot(None))
.filter(SqlaTable.schema != "")
.filter(or_(SqlaTable.perm.in_(perms)))
.distinct()
)
accessible_schemas.update([table.schema for table in tables])
accessible_schemas.update(
{
table.schema or default_schema # type: ignore
for table in tables
if (table.schema or default_schema)
}
)
return [s for s in schemas if s in accessible_schemas]
return schemas & accessible_schemas
def get_catalogs_accessible_by_user(
    self,
    database: "Database",
    catalogs: set[str],
    hierarchical: bool = True,
) -> set[str]:
    """
    Return a filtered set of the catalogs accessible by the user.

    A catalog is accessible when the user has any of:

    - access to the whole database (only when ``hierarchical`` is true);
    - an explicit ``catalog_access`` permission on the catalog;
    - a ``schema_access`` permission on one of its schemas; or
    - ``datasource_access`` on a dataset that lives in it.

    :param database: The SQL database
    :param catalogs: A set of candidate catalogs
    :param hierarchical: Whether to check using the hierarchical permission logic
    :returns: The set of accessible database catalogs
    """
    # pylint: disable=import-outside-toplevel
    from superset.connectors.sqla.models import SqlaTable

    if hierarchical and self.can_access_database(database):
        return catalogs

    # catalog access; perms have the form [database].[catalog]
    accessible_catalogs: set[str] = set()
    catalog_access = self.user_view_menu_names("catalog_access")
    default_catalog = database.get_default_catalog()
    for perm in catalog_access:
        parts = [part[1:-1] for part in perm.split(".")]
        if parts[0] == database.database_name:
            accessible_catalogs.add(parts[1])

    # schema access; perms are either [database].[schema] (legacy, implies
    # the default catalog) or [database].[catalog].[schema]
    schema_access = self.user_view_menu_names("schema_access")
    for perm in schema_access:
        parts = [part[1:-1] for part in perm.split(".")]
        if parts[0] != database.database_name:
            continue
        if len(parts) == 2 and default_catalog:
            accessible_catalogs.add(default_catalog)
        elif len(parts) == 3:
            # fix: the catalog is the middle part; parts[2] is the schema
            accessible_catalogs.add(parts[1])

    # datasource_access
    if perms := self.user_view_menu_names("datasource_access"):
        tables = (
            # fix: select the catalog column, since rows are read via
            # `table.catalog` below (the original selected SqlaTable.schema)
            self.get_session.query(SqlaTable.catalog)
            .filter(SqlaTable.database_id == database.id)
            .filter(or_(SqlaTable.perm.in_(perms)))
            .distinct()
        )
        accessible_catalogs.update(
            {
                table.catalog or default_catalog  # type: ignore
                for table in tables
                if (table.catalog or default_catalog)
            }
        )

    return catalogs & accessible_catalogs
def get_datasources_accessible_by_user( # pylint: disable=invalid-name
self,
database: "Database",
datasource_names: list[DatasourceName],
catalog: Optional[str] = None,
schema: Optional[str] = None,
) -> list[DatasourceName]:
"""
Return the list of SQL tables accessible by the user.
Filter list of SQL tables to the ones accessible by the user.
When catalog and/or schema are specified, it's assumed that all datasources in
`datasource_names` are in the given catalog/schema.
:param database: The SQL database
:param datasource_names: The list of eligible SQL tables w/ schema
:param catalog: The fallback SQL catalog if not present in the table name
:param schema: The fallback SQL schema if not present in the table name
:returns: The list of accessible SQL tables w/ schema
"""
@ -764,22 +898,34 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
if self.can_access_database(database):
return datasource_names
if catalog:
catalog_perm = self.get_catalog_perm(database.database_name, catalog)
if catalog_perm and self.can_access("catalog_access", catalog_perm):
return datasource_names
if schema:
schema_perm = self.get_schema_perm(database, schema)
schema_perm = self.get_schema_perm(database, catalog, schema)
if schema_perm and self.can_access("schema_access", schema_perm):
return datasource_names
user_perms = self.user_view_menu_names("datasource_access")
catalog_perms = self.user_view_menu_names("catalog_access")
schema_perms = self.user_view_menu_names("schema_access")
user_datasources = SqlaTable.query_datasources_by_permissions(
database, user_perms, schema_perms
)
if schema:
names = {d.table_name for d in user_datasources if d.schema == schema}
return [d for d in datasource_names if d.table in names]
user_datasources = {
DatasourceName(table.table_name, table.schema, table.catalog)
for table in SqlaTable.query_datasources_by_permissions(
database,
user_perms,
catalog_perms,
schema_perms,
)
}
full_names = {d.full_name for d in user_datasources}
return [d for d in datasource_names if f"[{database}].[{d}]" in full_names]
return [
datasource
for datasource in datasource_names
if datasource in user_datasources
]
def merge_perm(self, permission_name: str, view_menu_name: str) -> None:
"""
@ -844,6 +990,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
for datasource in datasources:
merge_pv("datasource_access", datasource.get_perm())
merge_pv("schema_access", datasource.get_schema_perm())
merge_pv("catalog_access", datasource.get_catalog_perm())
logger.info("Creating missing database permissions.")
databases = self.get_session.query(models.Database).all()
@ -1212,7 +1359,12 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
self.get_session.query(self.permissionview_model)
.join(self.permission_model)
.join(self.viewmenu_model)
.filter(self.permission_model.name == "schema_access")
.filter(
or_(
self.permission_model.name == "schema_access",
self.permission_model.name == "catalog_access",
)
)
.filter(self.viewmenu_model.name.like(f"[{database_name}].[%]"))
.all()
)
@ -1399,18 +1551,43 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
.values(perm=dataset_perm)
)
# update catalog and schema perms
values: dict[str, Optional[str]] = {}
if target.schema:
dataset_schema_perm = self.get_schema_perm(
database.database_name, target.schema
database.database_name,
target.catalog,
target.schema,
)
self._insert_pvm_on_sqla_event(
mapper, connection, "schema_access", dataset_schema_perm
mapper,
connection,
"schema_access",
dataset_schema_perm,
)
target.schema_perm = dataset_schema_perm
values["schema_perm"] = dataset_schema_perm
if target.catalog:
dataset_catalog_perm = self.get_catalog_perm(
database.database_name,
target.catalog,
)
self._insert_pvm_on_sqla_event(
mapper,
connection,
"catalog_access",
dataset_catalog_perm,
)
target.catalog_perm = dataset_catalog_perm
values["catalog_perm"] = dataset_catalog_perm
if values:
connection.execute(
dataset_table.update()
.where(dataset_table.c.id == target.id)
.values(schema_perm=dataset_schema_perm)
.values(**values)
)
def dataset_after_delete(
@ -1467,6 +1644,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
table.select().where(table.c.id == target.id)
).one()
current_db_id = current_dataset.database_id
current_catalog = current_dataset.catalog
current_schema = current_dataset.schema
current_table_name = current_dataset.table_name
@ -1479,14 +1657,21 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
mapper, connection, target.perm, new_dataset_vm_name, target
)
# Updates schema permissions
new_dataset_schema_name = self.get_schema_perm(
target.database.database_name, target.schema
# Updates catalog/schema permissions
dataset_catalog_name = self.get_catalog_perm(
target.database.database_name,
target.catalog,
)
self._update_dataset_schema_perm(
dataset_schema_name = self.get_schema_perm(
target.database.database_name,
target.catalog,
target.schema,
)
self._update_dataset_catalog_schema_perm(
mapper,
connection,
new_dataset_schema_name,
dataset_catalog_name,
dataset_schema_name,
target,
)
@ -1502,23 +1687,32 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
mapper, connection, old_dataset_vm_name, new_dataset_vm_name, target
)
# When schema changes
if current_schema != target.schema:
new_dataset_schema_name = self.get_schema_perm(
target.database.database_name, target.schema
# When catalog/schema change
if current_catalog != target.catalog or current_schema != target.schema:
dataset_catalog_name = self.get_catalog_perm(
target.database.database_name,
target.catalog,
)
self._update_dataset_schema_perm(
dataset_schema_name = self.get_schema_perm(
target.database.database_name,
target.catalog,
target.schema,
)
self._update_dataset_catalog_schema_perm(
mapper,
connection,
new_dataset_schema_name,
dataset_catalog_name,
dataset_schema_name,
target,
)
def _update_dataset_schema_perm(
# pylint: disable=invalid-name, too-many-arguments
def _update_dataset_catalog_schema_perm(
self,
mapper: Mapper,
connection: Connection,
new_schema_permission_name: Optional[str],
catalog_permission_name: Optional[str],
schema_permission_name: Optional[str],
target: "SqlaTable",
) -> None:
"""
@ -1530,11 +1724,11 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
:param mapper: The SQLA event mapper
:param connection: The SQLA connection
:param new_schema_permission_name: The new schema permission name that changed
:param catalog_permission_name: The new catalog permission name that changed
:param schema_permission_name: The new schema permission name that changed
:param target: Dataset that was updated
:return:
"""
logger.info("Updating schema perm, new: %s", new_schema_permission_name)
from superset.connectors.sqla.models import ( # pylint: disable=import-outside-toplevel
SqlaTable,
)
@ -1545,18 +1739,30 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
sqlatable_table = SqlaTable.__table__ # pylint: disable=no-member
chart_table = Slice.__table__ # pylint: disable=no-member
# insert new schema PVM if it does not exist
# insert new PVMs if they don't exist
self._insert_pvm_on_sqla_event(
mapper, connection, "schema_access", new_schema_permission_name
mapper,
connection,
"catalog_access",
catalog_permission_name,
)
self._insert_pvm_on_sqla_event(
mapper,
connection,
"schema_access",
schema_permission_name,
)
# Update dataset (SqlaTable schema_perm field)
# Update dataset
connection.execute(
sqlatable_table.update()
.where(
sqlatable_table.c.id == target.id,
)
.values(schema_perm=new_schema_permission_name)
.values(
catalog_perm=catalog_permission_name,
schema_perm=schema_permission_name,
)
)
# Update charts (Slice schema_perm field)
@ -1566,7 +1772,10 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
chart_table.c.datasource_id == target.id,
chart_table.c.datasource_type == DatasourceType.TABLE,
)
.values(schema_perm=new_schema_permission_name)
.values(
catalog_perm=catalog_permission_name,
schema_perm=schema_permission_name,
)
)
def _update_dataset_perm( # pylint: disable=too-many-arguments
@ -1923,7 +2132,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
table: Optional["Table"] = None,
viz: Optional["BaseViz"] = None,
sql: Optional[str] = None,
catalog: Optional[str] = None, # pylint: disable=unused-argument
catalog: Optional[str] = None,
schema: Optional[str] = None,
) -> None:
"""
@ -1947,7 +2156,6 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.models.sql_lab import Query
from superset.sql_parse import Table
from superset.utils.core import shortid
if sql and database:
@ -1955,6 +2163,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
database=database,
sql=sql,
schema=schema,
catalog=catalog,
client_id=shortid()[:10],
user_id=get_user_id(),
)
@ -1970,9 +2179,23 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
return
if query:
# Getting the default schema for a query is hard. Users can select the
# schema in SQL Lab, but there's no guarantee that the query actually
# will run in that schema. Each DB engine spec needs to implement the
# necessary logic to enforce that the query runs in the selected schema.
# If the DB engine spec doesn't implement the logic the schema is read
# from the SQLAlchemy URI if possible; if not, we use the SQLAlchemy
# inspector to read it.
default_schema = database.get_default_schema_for_query(query)
# Determining the default catalog is much easier, because DB engine
# specs need explicit support for catalogs.
default_catalog = database.get_default_catalog()
tables = {
Table(table_.table, table_.schema or default_schema)
Table(
table_.table,
table_.schema or default_schema,
table_.catalog or default_catalog,
)
for table_ in extract_tables_from_jinja_sql(query.sql, database)
}
elif table:
@ -1981,21 +2204,36 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
denied = set()
for table_ in tables:
schema_perm = self.get_schema_perm(database, schema=table_.schema)
catalog_perm = self.get_catalog_perm(
database.database_name,
table_.catalog,
)
if catalog_perm and self.can_access("catalog_access", catalog_perm):
continue
if not (schema_perm and self.can_access("schema_access", schema_perm)):
datasources = SqlaTable.query_datasources_by_name(
database, table_.table, schema=table_.schema
)
schema_perm = self.get_schema_perm(
database,
table_.catalog,
table_.schema,
)
if schema_perm and self.can_access("schema_access", schema_perm):
continue
# Access to any datasource is suffice.
for datasource_ in datasources:
if self.can_access(
"datasource_access", datasource_.perm
) or self.is_owner(datasource_):
break
else:
denied.add(table_)
datasources = SqlaTable.query_datasources_by_name(
database,
table_.table,
schema=table_.schema,
catalog=table_.catalog,
)
for datasource_ in datasources:
if self.can_access(
"datasource_access",
datasource_.perm,
) or self.is_owner(datasource_):
# access to any datasource is sufficient
break
else:
denied.add(table_)
if denied:
raise SupersetSecurityException(

View File

@ -119,7 +119,11 @@ def memoized_func(key: str, cache: Cache = cache_manager.cache) -> Callable[...,
def wrap(f: Callable[..., Any]) -> Callable[..., Any]:
def wrapped_f(*args: Any, **kwargs: Any) -> Any:
if not kwargs.get("cache", True):
should_cache = kwargs.pop("cache", True)
force = kwargs.pop("force", False)
cache_timeout = kwargs.pop("cache_timeout", 0)
if not should_cache:
return f(*args, **kwargs)
# format the key using args/kwargs passed to the decorated function
@ -129,10 +133,10 @@ def memoized_func(key: str, cache: Cache = cache_manager.cache) -> Callable[...,
cache_key = key.format(**bound_args.arguments)
obj = cache.get(cache_key)
if not kwargs.get("force") and obj is not None:
if not force and obj is not None:
return obj
obj = f(*args, **kwargs)
cache.set(cache_key, obj, timeout=kwargs.get("cache_timeout", 0))
cache.set(cache_key, obj, timeout=cache_timeout)
return obj
return wrapped_f

View File

@ -679,11 +679,13 @@ def generic_find_uq_constraint_name(
def get_datasource_full_name(
database_name: str, datasource_name: str, schema: str | None = None
database_name: str,
datasource_name: str,
catalog: str | None = None,
schema: str | None = None,
) -> str:
if not schema:
return f"[{database_name}].[{datasource_name}]"
return f"[{database_name}].[{schema}].[{datasource_name}]"
parts = [database_name, catalog, schema, datasource_name]
return ".".join([f"[{part}]" for part in parts if part])
def validate_json(obj: bytes | bytearray | str) -> None:
@ -1051,8 +1053,8 @@ def merge_extra_form_data(form_data: dict[str, Any]) -> None:
"adhoc_filters", []
)
adhoc_filters.extend(
{"isExtra": True, **fltr} # type: ignore
for fltr in append_adhoc_filters
{"isExtra": True, **adhoc_filter} # type: ignore
for adhoc_filter in append_adhoc_filters
)
if append_filters:
for key, value in form_data.items():
@ -1502,6 +1504,7 @@ def shortid() -> str:
class DatasourceName(NamedTuple):
table: str
schema: str
catalog: str | None = None
def get_stacktrace() -> str | None:

View File

@ -32,10 +32,12 @@ def get_dataset_access_filters(
database_ids = security_manager.get_accessible_databases()
perms = security_manager.user_view_menu_names("datasource_access")
schema_perms = security_manager.user_view_menu_names("schema_access")
catalog_perms = security_manager.user_view_menu_names("catalog_access")
return or_(
Database.id.in_(database_ids),
base_model.perm.in_(perms),
base_model.catalog_perm.in_(catalog_perms),
base_model.schema_perm.in_(schema_perms),
*args,
)

View File

@ -211,11 +211,29 @@ class DatabaseMixin:
utils.parse_ssl_cert(database.server_cert)
database.set_sqlalchemy_uri(database.sqlalchemy_uri)
security_manager.add_permission_view_menu("database_access", database.perm)
# adding a new database we always want to force refresh schema list
for schema in database.get_all_schema_names():
security_manager.add_permission_view_menu(
"schema_access", security_manager.get_schema_perm(database, schema)
)
# add catalog/schema permissions
if database.db_engine_spec.supports_catalog:
catalogs = database.get_all_catalog_names()
for catalog in catalogs:
security_manager.add_permission_view_menu(
"catalog_access",
security_manager.get_catalog_perm(database.database_name, catalog),
)
else:
# add a dummy catalog for DBs that don't support them
catalogs = [None]
for catalog in catalogs:
for schema in database.get_all_schema_names(catalog=catalog):
security_manager.add_permission_view_menu(
"schema_access",
security_manager.get_schema_perm(
database.database_name,
catalog,
schema,
),
)
def pre_add(self, database: Database) -> None:
self._pre_add_update(database)

View File

@ -36,7 +36,6 @@ from sqlalchemy.exc import DBAPIError
from sqlalchemy.sql import func
from superset import db, security_manager
from superset.commands.database.ssh_tunnel.exceptions import SSHTunnelDatabasePortError # noqa: F401
from superset.connectors.sqla.models import SqlaTable
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.databases.utils import make_url_safe # noqa: F401
@ -296,14 +295,14 @@ class TestDatabaseApi(SupersetTestCase):
"superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
)
@mock.patch("superset.commands.database.create.is_feature_enabled")
@mock.patch(
"superset.models.core.Database.get_all_schema_names",
)
@mock.patch("superset.models.core.Database.get_all_catalog_names")
@mock.patch("superset.models.core.Database.get_all_schema_names")
def test_create_database_with_ssh_tunnel(
self,
mock_test_connection_database_command_run,
mock_create_is_feature_enabled,
mock_get_all_schema_names,
mock_get_all_catalog_names,
mock_create_is_feature_enabled,
mock_test_connection_database_command_run,
):
"""
Database API: Test create with SSH Tunnel
@ -344,14 +343,14 @@ class TestDatabaseApi(SupersetTestCase):
"superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
)
@mock.patch("superset.commands.database.create.is_feature_enabled")
@mock.patch(
"superset.models.core.Database.get_all_schema_names",
)
@mock.patch("superset.models.core.Database.get_all_catalog_names")
@mock.patch("superset.models.core.Database.get_all_schema_names")
def test_create_database_with_missing_port_raises_error(
self,
mock_test_connection_database_command_run,
mock_create_is_feature_enabled,
mock_get_all_schema_names,
mock_get_all_catalog_names,
mock_create_is_feature_enabled,
mock_test_connection_database_command_run,
):
"""
Database API: Test that missing port raises SSHTunnelDatabaseError
@ -397,15 +396,15 @@ class TestDatabaseApi(SupersetTestCase):
)
@mock.patch("superset.commands.database.create.is_feature_enabled")
@mock.patch("superset.commands.database.update.is_feature_enabled")
@mock.patch(
"superset.models.core.Database.get_all_schema_names",
)
@mock.patch("superset.models.core.Database.get_all_catalog_names")
@mock.patch("superset.models.core.Database.get_all_schema_names")
def test_update_database_with_ssh_tunnel(
self,
mock_test_connection_database_command_run,
mock_create_is_feature_enabled,
mock_update_is_feature_enabled,
mock_get_all_schema_names,
mock_get_all_catalog_names,
mock_update_is_feature_enabled,
mock_create_is_feature_enabled,
mock_test_connection_database_command_run,
):
"""
Database API: Test update Database with SSH Tunnel
@ -458,15 +457,15 @@ class TestDatabaseApi(SupersetTestCase):
)
@mock.patch("superset.commands.database.create.is_feature_enabled")
@mock.patch("superset.commands.database.update.is_feature_enabled")
@mock.patch(
"superset.models.core.Database.get_all_schema_names",
)
@mock.patch("superset.models.core.Database.get_all_catalog_names")
@mock.patch("superset.models.core.Database.get_all_schema_names")
def test_update_database_with_missing_port_raises_error(
self,
mock_test_connection_database_command_run,
mock_create_is_feature_enabled,
mock_update_is_feature_enabled,
mock_get_all_schema_names,
mock_get_all_catalog_names,
mock_update_is_feature_enabled,
mock_create_is_feature_enabled,
mock_test_connection_database_command_run,
):
"""
Database API: Test that missing port raises SSHTunnelDatabaseError
@ -523,16 +522,16 @@ class TestDatabaseApi(SupersetTestCase):
@mock.patch("superset.commands.database.create.is_feature_enabled")
@mock.patch("superset.commands.database.update.is_feature_enabled")
@mock.patch("superset.commands.database.ssh_tunnel.delete.is_feature_enabled")
@mock.patch(
"superset.models.core.Database.get_all_schema_names",
)
@mock.patch("superset.models.core.Database.get_all_catalog_names")
@mock.patch("superset.models.core.Database.get_all_schema_names")
def test_delete_ssh_tunnel(
self,
mock_test_connection_database_command_run,
mock_create_is_feature_enabled,
mock_update_is_feature_enabled,
mock_delete_is_feature_enabled,
mock_get_all_schema_names,
mock_get_all_catalog_names,
mock_delete_is_feature_enabled,
mock_update_is_feature_enabled,
mock_create_is_feature_enabled,
mock_test_connection_database_command_run,
):
"""
Database API: Test deleting a SSH tunnel via Database update
@ -606,15 +605,15 @@ class TestDatabaseApi(SupersetTestCase):
)
@mock.patch("superset.commands.database.create.is_feature_enabled")
@mock.patch("superset.commands.database.update.is_feature_enabled")
@mock.patch(
"superset.models.core.Database.get_all_schema_names",
)
@mock.patch("superset.models.core.Database.get_all_catalog_names")
@mock.patch("superset.models.core.Database.get_all_schema_names")
def test_update_ssh_tunnel_via_database_api(
self,
mock_test_connection_database_command_run,
mock_create_is_feature_enabled,
mock_update_is_feature_enabled,
mock_get_all_schema_names,
mock_get_all_catalog_names,
mock_update_is_feature_enabled,
mock_create_is_feature_enabled,
mock_test_connection_database_command_run,
):
"""
Database API: Test update SSH Tunnel via Database API
@ -684,15 +683,15 @@ class TestDatabaseApi(SupersetTestCase):
@mock.patch(
"superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
)
@mock.patch(
"superset.models.core.Database.get_all_schema_names",
)
@mock.patch("superset.models.core.Database.get_all_catalog_names")
@mock.patch("superset.models.core.Database.get_all_schema_names")
@mock.patch("superset.commands.database.create.is_feature_enabled")
def test_cascade_delete_ssh_tunnel(
self,
mock_test_connection_database_command_run,
mock_get_all_schema_names,
mock_create_is_feature_enabled,
mock_get_all_schema_names,
mock_get_all_catalog_names,
mock_test_connection_database_command_run,
):
"""
Database API: SSH Tunnel gets deleted if Database gets deleted
@ -739,16 +738,16 @@ class TestDatabaseApi(SupersetTestCase):
"superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
)
@mock.patch("superset.commands.database.create.is_feature_enabled")
@mock.patch(
"superset.models.core.Database.get_all_schema_names",
)
@mock.patch("superset.models.core.Database.get_all_catalog_names")
@mock.patch("superset.models.core.Database.get_all_schema_names")
@mock.patch("superset.extensions.db.session.rollback")
def test_do_not_create_database_if_ssh_tunnel_creation_fails(
self,
mock_rollback,
mock_test_connection_database_command_run,
mock_create_is_feature_enabled,
mock_get_all_schema_names,
mock_get_all_catalog_names,
mock_create_is_feature_enabled,
mock_test_connection_database_command_run,
mock_rollback,
):
"""
Database API: Test rollback is called if SSH Tunnel creation fails
@ -788,14 +787,14 @@ class TestDatabaseApi(SupersetTestCase):
"superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
)
@mock.patch("superset.commands.database.create.is_feature_enabled")
@mock.patch(
"superset.models.core.Database.get_all_schema_names",
)
@mock.patch("superset.models.core.Database.get_all_catalog_names")
@mock.patch("superset.models.core.Database.get_all_schema_names")
def test_get_database_returns_related_ssh_tunnel(
self,
mock_test_connection_database_command_run,
mock_create_is_feature_enabled,
mock_get_all_schema_names,
mock_get_all_catalog_names,
mock_create_is_feature_enabled,
mock_test_connection_database_command_run,
):
"""
Database API: Test GET Database returns its related SSH Tunnel
@ -842,13 +841,13 @@ class TestDatabaseApi(SupersetTestCase):
@mock.patch(
"superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
)
@mock.patch(
"superset.models.core.Database.get_all_schema_names",
)
@mock.patch("superset.models.core.Database.get_all_catalog_names")
@mock.patch("superset.models.core.Database.get_all_schema_names")
def test_if_ssh_tunneling_flag_is_not_active_it_raises_new_exception(
self,
mock_test_connection_database_command_run,
mock_get_all_schema_names,
mock_get_all_catalog_names,
mock_test_connection_database_command_run,
):
"""
Database API: Test raises SSHTunneling feature flag not enabled
@ -1987,13 +1986,13 @@ class TestDatabaseApi(SupersetTestCase):
rv = self.client.get(f"api/v1/database/{database.id}/schemas/")
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(schemas, response["result"])
self.assertEqual(schemas, set(response["result"]))
rv = self.client.get(
f"api/v1/database/{database.id}/schemas/?q={prison.dumps({'force': True})}"
)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(schemas, response["result"])
self.assertEqual(schemas, set(response["result"]))
def test_database_schemas_not_found(self):
"""

View File

@ -1095,7 +1095,7 @@ class TestTablesDatabaseCommand(SupersetTestCase):
@patch("superset.daos.database.DatabaseDAO.find_by_id")
def test_database_tables_list_with_unknown_database(self, mock_find_by_id):
mock_find_by_id.return_value = None
command = TablesDatabaseCommand(1, "test", False)
command = TablesDatabaseCommand(1, None, "test", False)
with pytest.raises(DatabaseNotFoundError) as excinfo:
command.run()
@ -1115,7 +1115,7 @@ class TestTablesDatabaseCommand(SupersetTestCase):
mock_can_access_database.side_effect = SupersetException("Test Error")
mock_g.user = security_manager.find_user("admin")
command = TablesDatabaseCommand(database.id, "main", False)
command = TablesDatabaseCommand(database.id, None, "main", False)
with pytest.raises(SupersetException) as excinfo:
command.run()
assert str(excinfo.value) == "Test Error"
@ -1131,7 +1131,7 @@ class TestTablesDatabaseCommand(SupersetTestCase):
mock_can_access_database.side_effect = Exception("Test Error")
mock_g.user = security_manager.find_user("admin")
command = TablesDatabaseCommand(database.id, "main", False)
command = TablesDatabaseCommand(database.id, None, "main", False)
with pytest.raises(DatabaseTablesUnexpectedError) as excinfo:
command.run()
assert (
@ -1154,7 +1154,7 @@ class TestTablesDatabaseCommand(SupersetTestCase):
if database.backend == "postgresql" or database.backend == "mysql":
return
command = TablesDatabaseCommand(database.id, schema_name, False)
command = TablesDatabaseCommand(database.id, None, schema_name, False)
result = command.run()
assert result["count"] > 0

View File

@ -531,7 +531,7 @@ def test_get_catalog_names(app_context: AppContext) -> None:
return
with database.get_inspector() as inspector:
assert PostgresEngineSpec.get_catalog_names(database, inspector) == [
assert PostgresEngineSpec.get_catalog_names(database, inspector) == {
"postgres",
"superset",
]
}

View File

@ -343,20 +343,20 @@ class TestDatabaseModel(SupersetTestCase):
main_db = get_example_database()
if main_db.backend == "mysql":
df = main_db.get_df("SELECT 1", None)
df = main_db.get_df("SELECT 1", None, None)
self.assertEqual(df.iat[0, 0], 1)
df = main_db.get_df("SELECT 1;", None)
df = main_db.get_df("SELECT 1;", None, None)
self.assertEqual(df.iat[0, 0], 1)
def test_multi_statement(self):
main_db = get_example_database()
if main_db.backend == "mysql":
df = main_db.get_df("USE superset; SELECT 1", None)
df = main_db.get_df("USE superset; SELECT 1", None, None)
self.assertEqual(df.iat[0, 0], 1)
df = main_db.get_df("USE superset; SELECT ';';", None)
df = main_db.get_df("USE superset; SELECT ';';", None, None)
self.assertEqual(df.iat[0, 0], ";")
@mock.patch("superset.models.core.create_engine")

View File

@ -898,7 +898,7 @@ class TestRolePermission(SupersetTestCase):
db.session.query(SqlaTable).filter_by(table_name="tmp_table1").one()
)
self.assertEqual(changed_table1.perm, f"[tmp_db2].[tmp_table1](id:{table1.id})")
self.assertEqual(changed_table1.schema_perm, f"[tmp_db2].[tmp_schema]") # noqa: F541
self.assertEqual(changed_table1.schema_perm, "[tmp_db2].[tmp_schema]") # noqa: F541
# Test Chart permission changed
slice1 = db.session.query(Slice).filter_by(slice_name="tmp_slice1").one()
@ -956,12 +956,12 @@ class TestRolePermission(SupersetTestCase):
db.session.query(SqlaTable).filter_by(table_name="tmp_table1").one()
)
self.assertEqual(changed_table1.perm, f"[tmp_db1].[tmp_table1](id:{table1.id})")
self.assertEqual(changed_table1.schema_perm, f"[tmp_db1].[tmp_schema_changed]") # noqa: F541
self.assertEqual(changed_table1.schema_perm, "[tmp_db1].[tmp_schema_changed]") # noqa: F541
# Test Chart schema permission changed
slice1 = db.session.query(Slice).filter_by(slice_name="tmp_slice1").one()
self.assertEqual(slice1.perm, f"[tmp_db1].[tmp_table1](id:{table1.id})")
self.assertEqual(slice1.schema_perm, f"[tmp_db1].[tmp_schema_changed]") # noqa: F541
self.assertEqual(slice1.schema_perm, "[tmp_db1].[tmp_schema_changed]") # noqa: F541
# cleanup
db.session.delete(slice1)
@ -1069,7 +1069,7 @@ class TestRolePermission(SupersetTestCase):
self.assertEqual(
changed_table1.perm, f"[tmp_db2].[tmp_table1_changed](id:{table1.id})"
)
self.assertEqual(changed_table1.schema_perm, f"[tmp_db2].[tmp_schema]") # noqa: F541
self.assertEqual(changed_table1.schema_perm, "[tmp_db2].[tmp_schema]") # noqa: F541
# Test Chart permission changed
slice1 = db.session.query(Slice).filter_by(slice_name="tmp_slice1").one()
@ -1158,9 +1158,9 @@ class TestRolePermission(SupersetTestCase):
with self.client.application.test_request_context():
database = get_example_database()
schemas = security_manager.get_schemas_accessible_by_user(
database, ["1", "2", "3"]
database, None, {"1", "2", "3"}
)
self.assertEqual(schemas, ["1", "2", "3"]) # no changes
self.assertEqual(schemas, {"1", "2", "3"}) # no changes
@patch("superset.utils.core.g")
@patch("superset.security.manager.g")
@ -1171,10 +1171,10 @@ class TestRolePermission(SupersetTestCase):
with self.client.application.test_request_context():
database = get_example_database()
schemas = security_manager.get_schemas_accessible_by_user(
database, ["1", "2", "3"]
database, None, {"1", "2", "3"}
)
# temp_schema is not passed in the params
self.assertEqual(schemas, ["1"])
self.assertEqual(schemas, {"1"})
delete_schema_perm("[examples].[1]")
def test_schemas_accessible_by_user_datasource_access(self):
@ -1183,9 +1183,9 @@ class TestRolePermission(SupersetTestCase):
with self.client.application.test_request_context():
with override_user(security_manager.find_user("gamma")):
schemas = security_manager.get_schemas_accessible_by_user(
database, ["temp_schema", "2", "3"]
database, None, {"temp_schema", "2", "3"}
)
self.assertEqual(schemas, ["temp_schema"])
self.assertEqual(schemas, {"temp_schema"})
def test_schemas_accessible_by_user_datasource_and_schema_access(self):
# User has schema access to the datasource temp_schema.wb_health_population in examples DB.
@ -1194,9 +1194,9 @@ class TestRolePermission(SupersetTestCase):
database = get_example_database()
with override_user(security_manager.find_user("gamma")):
schemas = security_manager.get_schemas_accessible_by_user(
database, ["temp_schema", "2", "3"]
database, None, {"temp_schema", "2", "3"}
)
self.assertEqual(schemas, ["temp_schema", "2"])
self.assertEqual(schemas, {"temp_schema", "2"})
vm = security_manager.find_permission_view_menu(
"schema_access", "[examples].[2]"
)

View File

@ -284,9 +284,15 @@ class TestSqlLab(SupersetTestCase):
# sqlite doesn't support database creation
return
catalog = examples_db.get_default_catalog()
sqllab_test_db_schema_permission_view = (
security_manager.add_permission_view_menu(
"schema_access", f"[{examples_db.name}].[{CTAS_SCHEMA_NAME}]"
"schema_access",
security_manager.get_schema_perm(
examples_db.name,
catalog,
CTAS_SCHEMA_NAME,
),
)
)
schema_perm_role = security_manager.add_role("SchemaPermission")

View File

@ -0,0 +1,128 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from unittest.mock import MagicMock
import pytest
from pytest_mock import MockerFixture
from superset.commands.database.create import CreateDatabaseCommand
from superset.extensions import security_manager
@pytest.fixture()
def database_with_catalog(mocker: MockerFixture) -> MagicMock:
    """
    Mock a database with catalogs and schemas.

    The mocked database exposes two catalogs ("catalog1", "catalog2").
    Successive calls to ``get_all_schema_names`` return the schema set of
    each catalog in turn, matching how the create command iterates catalogs
    and fetches schemas per catalog.
    """
    # Neutralize the session and the connection test run by the command.
    mocker.patch("superset.commands.database.create.db")
    mocker.patch("superset.commands.database.create.TestConnectionDatabaseCommand")
    database = mocker.MagicMock()
    database.database_name = "test_database"
    database.db_engine_spec.__name__ = "test_engine"
    database.db_engine_spec.supports_catalog = True
    database.get_all_catalog_names.return_value = ["catalog1", "catalog2"]
    # One set of schemas per catalog, consumed in catalog order.
    database.get_all_schema_names.side_effect = [
        {"schema1", "schema2"},
        {"schema3", "schema4"},
    ]
    # The command gets this mock back from the DAO when it creates the DB.
    DatabaseDAO = mocker.patch("superset.commands.database.create.DatabaseDAO")
    DatabaseDAO.create.return_value = database
    return database
@pytest.fixture()
def database_without_catalog(mocker: MockerFixture) -> MagicMock:
    """
    Build a mocked catalog-less database and wire it into the create command.
    """
    # Neutralize session handling and the connection test run by the command.
    for target in (
        "superset.commands.database.create.db",
        "superset.commands.database.create.TestConnectionDatabaseCommand",
    ):
        mocker.patch(target)
    mock_database = mocker.MagicMock()
    mock_database.database_name = "test_database"
    mock_database.db_engine_spec.__name__ = "test_engine"
    mock_database.db_engine_spec.supports_catalog = False
    mock_database.get_all_schema_names.return_value = ["schema1", "schema2"]
    # The command receives this mock back from the DAO on creation.
    dao_mock = mocker.patch("superset.commands.database.create.DatabaseDAO")
    dao_mock.create.return_value = mock_database
    return mock_database
def test_create_permissions_with_catalog(
    mocker: MockerFixture,
    database_with_catalog: MockerFixture,
) -> None:
    """
    Test that permissions are created when a database with a catalog is created.

    Expects one ``catalog_access`` permission per catalog and one
    ``schema_access`` permission per (catalog, schema) pair, using the
    ``[db].[catalog]`` / ``[db].[catalog].[schema]`` naming scheme.
    """
    # Spy on permission registration instead of touching the real manager.
    add_permission_view_menu = mocker.patch.object(
        security_manager,
        "add_permission_view_menu",
    )
    CreateDatabaseCommand(
        {
            "database_name": "test_database",
            "sqlalchemy_uri": "sqlite://",
        }
    ).run()
    # any_order=True because the fixture yields schemas as sets, whose
    # iteration order is undefined.
    add_permission_view_menu.assert_has_calls(
        [
            mocker.call("catalog_access", "[test_database].[catalog1]"),
            mocker.call("catalog_access", "[test_database].[catalog2]"),
            mocker.call("schema_access", "[test_database].[catalog1].[schema1]"),
            mocker.call("schema_access", "[test_database].[catalog1].[schema2]"),
            mocker.call("schema_access", "[test_database].[catalog2].[schema3]"),
            mocker.call("schema_access", "[test_database].[catalog2].[schema4]"),
        ],
        any_order=True,
    )
def test_create_permissions_without_catalog(
    mocker: MockerFixture,
    database_without_catalog: MockerFixture,
) -> None:
    """
    Creating a catalog-less database registers two-part schema permissions only.
    """
    # Spy on permission registration instead of touching the real manager.
    perm_spy = mocker.patch.object(security_manager, "add_permission_view_menu")
    payload = {
        "database_name": "test_database",
        "sqlalchemy_uri": "sqlite://",
    }
    CreateDatabaseCommand(payload).run()
    expected_calls = [
        mocker.call("schema_access", "[test_database].[schema1]"),
        mocker.call("schema_access", "[test_database].[schema2]"),
    ]
    # Schema discovery order is not guaranteed, hence any_order.
    perm_spy.assert_has_calls(expected_calls, any_order=True)

View File

@ -0,0 +1,203 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from unittest.mock import MagicMock
import pytest
from pytest_mock import MockerFixture
from superset.commands.database.tables import TablesDatabaseCommand
from superset.extensions import security_manager
from superset.utils.core import DatasourceName
@pytest.fixture()
def database_with_catalog(mocker: MockerFixture) -> MagicMock:
    """
    Mock a database with catalogs and schemas.

    Exposes two tables and one view, all under ``catalog1.schema1``; the DAO
    lookup used by the tables command returns this mock.
    """
    # Session handling inside the command is irrelevant to these tests.
    mocker.patch("superset.commands.database.tables.db")
    database = mocker.MagicMock()
    database.database_name = "test_database"
    database.get_all_table_names_in_schema.return_value = [
        DatasourceName("table1", "schema1", "catalog1"),
        DatasourceName("table2", "schema1", "catalog1"),
    ]
    database.get_all_view_names_in_schema.return_value = [
        DatasourceName("view1", "schema1", "catalog1"),
    ]
    DatabaseDAO = mocker.patch("superset.commands.database.tables.DatabaseDAO")
    DatabaseDAO.find_by_id.return_value = database
    return database
@pytest.fixture()
def database_without_catalog(mocker: MockerFixture) -> MagicMock:
    """
    Build a mocked catalog-less database exposing tables and views in "schema1".
    """
    # Session handling inside the command is irrelevant to these tests.
    mocker.patch("superset.commands.database.tables.db")
    db_mock = mocker.MagicMock()
    db_mock.database_name = "test_database"
    db_mock.get_all_table_names_in_schema.return_value = [
        DatasourceName("table1", "schema1"),
        DatasourceName("table2", "schema1"),
    ]
    db_mock.get_all_view_names_in_schema.return_value = [
        DatasourceName("view1", "schema1"),
    ]
    # The tables command looks the database up through the DAO.
    dao_mock = mocker.patch("superset.commands.database.tables.DatabaseDAO")
    dao_mock.find_by_id.return_value = db_mock
    return db_mock
def test_tables_with_catalog(
    mocker: MockerFixture,
    database_with_catalog: MockerFixture,
) -> None:
    """
    Test that ``TablesDatabaseCommand`` lists tables and views for a database
    with catalog support, passing the catalog through to the security check
    and to schema introspection.
    """
    # First side_effect entry filters the tables, second filters the views;
    # here the user is allowed to see everything.
    get_datasources_accessible_by_user = mocker.patch.object(
        security_manager,
        "get_datasources_accessible_by_user",
        side_effect=[
            {
                DatasourceName("table1", "schema1", "catalog1"),
                DatasourceName("table2", "schema1", "catalog1"),
            },
            {DatasourceName("view1", "schema1", "catalog1")},
        ],
    )
    db = mocker.patch("superset.commands.database.tables.db")
    # Only table1 has stored metadata, so table2's "extra" comes back None.
    table = mocker.MagicMock()
    table.name = "table1"
    table.extra_dict = {"foo": "bar"}
    db.session.query().filter().options().all.return_value = [table]
    payload = TablesDatabaseCommand(1, "catalog1", "schema1", False).run()
    assert payload == {
        "count": 3,
        "result": [
            {"value": "table1", "type": "table", "extra": {"foo": "bar"}},
            {"value": "table2", "type": "table", "extra": None},
            {"value": "view1", "type": "view"},
        ],
    }
    # The catalog must be forwarded to the security manager for both calls.
    get_datasources_accessible_by_user.assert_has_calls(
        [
            mocker.call(
                database=database_with_catalog,
                catalog="catalog1",
                schema="schema1",
                datasource_names=[
                    DatasourceName("table1", "schema1", "catalog1"),
                    DatasourceName("table2", "schema1", "catalog1"),
                ],
            ),
            mocker.call(
                database=database_with_catalog,
                catalog="catalog1",
                schema="schema1",
                datasource_names=[
                    DatasourceName("view1", "schema1", "catalog1"),
                ],
            ),
        ],
    )
    # ... and to schema introspection as well.
    database_with_catalog.get_all_table_names_in_schema.assert_called_with(
        catalog="catalog1",
        schema="schema1",
        force=False,
        cache=database_with_catalog.table_cache_enabled,
        cache_timeout=database_with_catalog.table_cache_timeout,
    )
def test_tables_without_catalog(
    mocker: MockerFixture,
    database_without_catalog: MockerFixture,
) -> None:
    """
    Test that ``TablesDatabaseCommand`` lists tables and views for a database
    without catalog support, passing ``catalog=None`` everywhere.
    """
    # First side_effect entry filters the tables, second filters the views;
    # here the user is allowed to see everything.
    get_datasources_accessible_by_user = mocker.patch.object(
        security_manager,
        "get_datasources_accessible_by_user",
        side_effect=[
            {
                DatasourceName("table1", "schema1"),
                DatasourceName("table2", "schema1"),
            },
            {DatasourceName("view1", "schema1")},
        ],
    )
    db = mocker.patch("superset.commands.database.tables.db")
    # Only table1 has stored metadata, so table2's "extra" comes back None.
    table = mocker.MagicMock()
    table.name = "table1"
    table.extra_dict = {"foo": "bar"}
    db.session.query().filter().options().all.return_value = [table]
    payload = TablesDatabaseCommand(1, None, "schema1", False).run()
    assert payload == {
        "count": 3,
        "result": [
            {"value": "table1", "type": "table", "extra": {"foo": "bar"}},
            {"value": "table2", "type": "table", "extra": None},
            {"value": "view1", "type": "view"},
        ],
    }
    # catalog=None must be forwarded to the security manager for both calls.
    get_datasources_accessible_by_user.assert_has_calls(
        [
            mocker.call(
                database=database_without_catalog,
                catalog=None,
                schema="schema1",
                datasource_names=[
                    DatasourceName("table1", "schema1"),
                    DatasourceName("table2", "schema1"),
                ],
            ),
            mocker.call(
                database=database_without_catalog,
                catalog=None,
                schema="schema1",
                datasource_names=[
                    DatasourceName("view1", "schema1"),
                ],
            ),
        ],
    )
    # ... and to schema introspection as well.
    database_without_catalog.get_all_table_names_in_schema.assert_called_with(
        catalog=None,
        schema="schema1",
        force=False,
        cache=database_without_catalog.table_cache_enabled,
        cache_timeout=database_without_catalog.table_cache_timeout,
    )

View File

@ -0,0 +1,272 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from unittest.mock import MagicMock
import pytest
from pytest_mock import MockerFixture
from superset.commands.database.update import UpdateDatabaseCommand
from superset.extensions import security_manager
@pytest.fixture()
def database_with_catalog(mocker: MockerFixture) -> MagicMock:
    """
    Mock a database with catalogs and schemas.

    Two catalogs ("catalog1", "catalog2"); successive calls to
    ``get_all_schema_names`` return each catalog's schemas in turn.
    """
    # Session handling inside the update command is irrelevant here.
    mocker.patch("superset.commands.database.update.db")
    database = mocker.MagicMock()
    database.database_name = "my_db"
    database.db_engine_spec.__name__ = "test_engine"
    database.db_engine_spec.supports_catalog = True
    database.get_all_catalog_names.return_value = ["catalog1", "catalog2"]
    # One list of schemas per catalog, consumed in catalog order.
    database.get_all_schema_names.side_effect = [
        ["schema1", "schema2"],
        ["schema3", "schema4"],
    ]
    database.get_default_catalog.return_value = "catalog2"
    return database
@pytest.fixture()
def database_without_catalog(mocker: MockerFixture) -> MagicMock:
    """
    Build a mocked database whose engine spec lacks catalog support.
    """
    # Session handling inside the update command is irrelevant here.
    mocker.patch("superset.commands.database.update.db")
    db_mock = mocker.MagicMock()
    db_mock.database_name = "my_db"
    engine_spec = db_mock.db_engine_spec
    engine_spec.__name__ = "test_engine"
    engine_spec.supports_catalog = False
    db_mock.get_all_schema_names.return_value = ["schema1", "schema2"]
    return db_mock
def test_update_with_catalog(
    mocker: MockerFixture,
    database_with_catalog: MockerFixture,
) -> None:
    """
    Test that permissions are updated correctly.

    In this test, the database has two catalogs with two schemas each:

        - catalog1
            - schema1
            - schema2
        - catalog2
            - schema3
            - schema4

    When update is called, only ``catalog2.schema3`` has permissions
    associated with it, so ``catalog1.*`` and ``catalog2.schema4`` are added.
    """
    DatabaseDAO = mocker.patch("superset.commands.database.update.DatabaseDAO")
    DatabaseDAO.find_by_id.return_value = database_with_catalog
    DatabaseDAO.update.return_value = database_with_catalog
    # The side_effect order mirrors the exact lookup order inside the command.
    find_permission_view_menu = mocker.patch.object(
        security_manager,
        "find_permission_view_menu",
    )
    find_permission_view_menu.side_effect = [
        None,  # first catalog is new
        "[my_db].[catalog2]",  # second catalog already exists
        "[my_db].[catalog2].[schema3]",  # first schema already exists
        None,  # second schema is new
        # these are called when checking for existing perms in [db].[schema] format
        None,
        None,
    ]
    add_permission_view_menu = mocker.patch.object(
        security_manager,
        "add_permission_view_menu",
    )
    UpdateDatabaseCommand(1, {}).run()
    add_permission_view_menu.assert_has_calls(
        [
            # first catalog is added with all schemas
            mocker.call("catalog_access", "[my_db].[catalog1]"),
            mocker.call("schema_access", "[my_db].[catalog1].[schema1]"),
            mocker.call("schema_access", "[my_db].[catalog1].[schema2]"),
            # second catalog already exists, only `schema4` is added
            mocker.call("schema_access", "[my_db].[catalog2].[schema4]"),
        ],
    )
def test_update_without_catalog(
    mocker: MockerFixture,
    database_without_catalog: MockerFixture,
) -> None:
    """
    Test that permissions are updated correctly.

    In this test, the database has no catalogs and two schemas:

        - schema1
        - schema2

    When update is called, only ``schema2`` has permissions associated with
    it, so ``schema1`` is added.
    """
    DatabaseDAO = mocker.patch("superset.commands.database.update.DatabaseDAO")
    DatabaseDAO.find_by_id.return_value = database_without_catalog
    DatabaseDAO.update.return_value = database_without_catalog
    # The side_effect order mirrors the exact lookup order inside the command.
    find_permission_view_menu = mocker.patch.object(
        security_manager,
        "find_permission_view_menu",
    )
    find_permission_view_menu.side_effect = [
        None,  # schema1 has no permissions
        "[my_db].[schema2]",  # second schema already exists
    ]
    add_permission_view_menu = mocker.patch.object(
        security_manager,
        "add_permission_view_menu",
    )
    UpdateDatabaseCommand(1, {}).run()
    # Only the missing schema permission is registered.
    add_permission_view_menu.assert_called_with(
        "schema_access",
        "[my_db].[schema1]",
    )
def test_rename_with_catalog(
    mocker: MockerFixture,
    database_with_catalog: MockerFixture,
) -> None:
    """
    Test that permissions are renamed correctly.

    In this test, the database has two catalogs with two schemas each:

        - catalog1
            - schema1
            - schema2
        - catalog2
            - schema3
            - schema4

    When update is called, only ``catalog2.schema3`` has permissions
    associated with it, so ``catalog1.*`` and ``catalog2.schema4`` are
    added. Additionally, the database has been renamed from ``my_db`` to
    ``my_other_db``, so the existing permissions must be renamed in place.
    """
    DatabaseDAO = mocker.patch("superset.commands.database.update.DatabaseDAO")
    # The DAO returns the pre-update database (old name) on lookup and the
    # post-update one (new name) after the update.
    original_database = mocker.MagicMock()
    original_database.database_name = "my_db"
    DatabaseDAO.find_by_id.return_value = original_database
    database_with_catalog.database_name = "my_other_db"
    DatabaseDAO.update.return_value = database_with_catalog
    DatabaseDAO.get_datasets.return_value = []
    find_permission_view_menu = mocker.patch.object(
        security_manager,
        "find_permission_view_menu",
    )
    # PVM mocks for the pre-existing permissions that should get renamed.
    catalog2_pvm = mocker.MagicMock()
    catalog2_schema3_pvm = mocker.MagicMock()
    # The side_effect order mirrors the exact lookup order inside the command.
    find_permission_view_menu.side_effect = [
        # these are called when adding the permissions:
        None,  # first catalog is new
        "[my_db].[catalog2]",  # second catalog already exists
        "[my_db].[catalog2].[schema3]",  # first schema already exists
        None,  # second schema is new
        # these are called when renaming the permissions:
        catalog2_pvm,  # old [my_db].[catalog2]
        catalog2_schema3_pvm,  # old [my_db].[catalog2].[schema3]
        None,  # [my_db].[catalog2].[schema4] doesn't exist
    ]
    add_permission_view_menu = mocker.patch.object(
        security_manager,
        "add_permission_view_menu",
    )
    UpdateDatabaseCommand(1, {}).run()
    add_permission_view_menu.assert_has_calls(
        [
            # first catalog is added with all schemas with the new DB name
            mocker.call("catalog_access", "[my_other_db].[catalog1]"),
            mocker.call("schema_access", "[my_other_db].[catalog1].[schema1]"),
            mocker.call("schema_access", "[my_other_db].[catalog1].[schema2]"),
            # second catalog already exists, only `schema4` is added
            mocker.call("schema_access", "[my_other_db].[catalog2].[schema4]"),
        ],
    )
    # Pre-existing permissions were renamed in place to the new DB name.
    assert catalog2_pvm.view_menu.name == "[my_other_db].[catalog2]"
    assert catalog2_schema3_pvm.view_menu.name == "[my_other_db].[catalog2].[schema3]"
def test_rename_without_catalog(
    mocker: MockerFixture,
    database_without_catalog: MockerFixture,
) -> None:
    """
    Test that permissions are renamed correctly.

    In this test, the database has no catalogs and two schemas:

        - schema1
        - schema2

    When update is called, only ``schema2`` has permissions associated with
    it, so ``schema1`` is added. Additionally, the database has been renamed
    from ``my_db`` to ``my_other_db``, so the existing schema permission must
    be renamed in place.
    """
    DatabaseDAO = mocker.patch("superset.commands.database.update.DatabaseDAO")
    # The DAO returns the pre-update database (old name) on lookup and the
    # post-update one (new name) after the update.
    original_database = mocker.MagicMock()
    original_database.database_name = "my_db"
    DatabaseDAO.find_by_id.return_value = original_database
    database_without_catalog.database_name = "my_other_db"
    DatabaseDAO.update.return_value = database_without_catalog
    DatabaseDAO.get_datasets.return_value = []
    find_permission_view_menu = mocker.patch.object(
        security_manager,
        "find_permission_view_menu",
    )
    # PVM mock for the pre-existing permission that should get renamed.
    schema2_pvm = mocker.MagicMock()
    # The side_effect order mirrors the exact lookup order inside the command.
    find_permission_view_menu.side_effect = [
        None,  # schema1 has no permissions
        "[my_db].[schema2]",  # second schema already exists
        None,  # [my_db].[schema1] doesn't exist
        schema2_pvm,  # old [my_db].[schema2]
    ]
    add_permission_view_menu = mocker.patch.object(
        security_manager,
        "add_permission_view_menu",
    )
    UpdateDatabaseCommand(1, {}).run()
    # Only the missing schema permission is registered, under the new name.
    add_permission_view_menu.assert_called_with(
        "schema_access",
        "[my_other_db].[schema1]",
    )
    # The pre-existing permission was renamed in place to the new DB name.
    assert schema2_pvm.view_menu.name == "[my_other_db].[schema2]"

View File

@ -88,6 +88,8 @@ def app(request: SubRequest) -> Iterator[SupersetApp]:
app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] = False
app.config["TESTING"] = True
app.config["RATELIMIT_ENABLED"] = False
app.config["CACHE_CONFIG"] = {}
app.config["DATA_CACHE_CONFIG"] = {}
# loop over extra configs passed in by tests
# and update the app config

View File

@ -17,9 +17,11 @@
import pytest
from pytest_mock import MockerFixture
from sqlalchemy import create_engine
from superset.connectors.sqla.models import SqlaTable
from superset.exceptions import OAuth2RedirectError
from superset.models.core import Database
from superset.superset_typing import QueryObjectDict
@ -64,3 +66,124 @@ def test_query_bubbles_errors(mocker: MockerFixture) -> None:
}
with pytest.raises(OAuth2RedirectError):
sqla_table.query(query_obj)
def test_permissions_without_catalog() -> None:
    """
    Permissions for a table whose catalog is ``None``.

    The catalog permission should be absent, and the schema permission
    should contain only the database and schema names.
    """
    table = SqlaTable(
        id=1,
        table_name="my_sqla_table",
        schema="schema1",
        catalog=None,
        columns=[],
        metrics=[],
        database=Database(database_name="my_db"),
    )

    assert table.get_perm() == "[my_db].[my_sqla_table](id:1)"
    assert table.get_catalog_perm() is None
    assert table.get_schema_perm() == "[my_db].[schema1]"
def test_permissions_with_catalog() -> None:
    """
    Permissions for a table that has a catalog set.

    The catalog permission is present, and the schema permission includes
    the catalog between the database and schema names.
    """
    table = SqlaTable(
        id=1,
        table_name="my_sqla_table",
        schema="schema1",
        catalog="db1",
        columns=[],
        metrics=[],
        database=Database(database_name="my_db"),
    )

    assert table.get_perm() == "[my_db].[my_sqla_table](id:1)"
    assert table.get_catalog_perm() == "[my_db].[db1]"
    assert table.get_schema_perm() == "[my_db].[db1].[schema1]"
def test_query_datasources_by_name(mocker: MockerFixture) -> None:
    """
    Test the `query_datasources_by_name` method.
    """
    mock_db = mocker.patch("superset.connectors.sqla.models.db")
    my_db = Database(database_name="my_db", id=1)
    table = SqlaTable(
        table_name="my_sqla_table",
        columns=[],
        metrics=[],
        database=my_db,
    )

    # without catalog/schema, only the database and table name are filtered on
    table.query_datasources_by_name(my_db, "my_table")
    mock_db.session.query().filter_by.assert_called_with(
        database_id=1,
        table_name="my_table",
    )

    # catalog and schema are added to the filter when provided
    table.query_datasources_by_name(my_db, "my_table", "db1", "schema1")
    mock_db.session.query().filter_by.assert_called_with(
        database_id=1,
        table_name="my_table",
        catalog="db1",
        schema="schema1",
    )
def test_query_datasources_by_permissions(mocker: MockerFixture) -> None:
    """
    Test the `query_datasources_by_permissions` method.
    """
    mock_db = mocker.patch("superset.connectors.sqla.models.db")
    sqlite_engine = create_engine("sqlite://")
    my_db = Database(database_name="my_db", id=1)
    table = SqlaTable(
        table_name="my_sqla_table",
        columns=[],
        metrics=[],
        database=my_db,
    )

    table.query_datasources_by_permissions(my_db, set(), set(), set())

    mock_db.session.query().filter_by.assert_called_with(database_id=1)
    # with no permissions passed, the `.filter()` clause renders as empty SQL
    clause = mock_db.session.query().filter_by().filter.mock_calls[0].args[0]
    rendered = str(clause.compile(sqlite_engine, compile_kwargs={"literal_binds": True}))
    assert rendered == ""
def test_query_datasources_by_permissions_with_catalog_schema(
    mocker: MockerFixture,
) -> None:
    """
    Test the `query_datasources_by_permissions` method passing a catalog and schema.
    """
    db = mocker.patch("superset.connectors.sqla.models.db")
    engine = create_engine("sqlite://")
    database = Database(database_name="my_db", id=1)
    sqla_table = SqlaTable(
        table_name="my_sqla_table",
        columns=[],
        metrics=[],
        database=database,
    )
    sqla_table.query_datasources_by_permissions(
        database,
        {"[my_db].[table1](id:1)"},
        {"[my_db].[db1]"},
        # pass as list to have deterministic order for test
        ["[my_db].[db1].[schema1]", "[my_other_db].[schema]"],  # type: ignore
    )
    # grab the SQLAlchemy clause passed to `.filter()` and render it with
    # literal values so we can assert on the generated SQL
    clause = db.session.query().filter_by().filter.mock_calls[0].args[0]
    assert str(clause.compile(engine, compile_kwargs={"literal_binds": True})) == (
        "tables.perm IN ('[my_db].[table1](id:1)') OR "
        "tables.schema_perm IN ('[my_db].[db1].[schema1]', '[my_other_db].[schema]') OR "
        "tables.catalog_perm IN ('[my_db].[db1]')"
    )

View File

@ -2066,3 +2066,101 @@ def test_table_extra_metadata_unauthorized(
}
]
}
def test_catalogs(
    mocker: MockFixture,
    client: Any,
    full_api_access: None,
) -> None:
    """
    Test the `catalogs` endpoint.
    """
    db_mock = mocker.MagicMock()
    db_mock.get_all_catalog_names.return_value = {"db1", "db2"}
    dao = mocker.patch("superset.databases.api.DatabaseDAO")
    dao.find_by_id.return_value = db_mock

    sm = mocker.patch(
        "superset.databases.api.security_manager",
        new=mocker.MagicMock(),
    )
    # the user can only see one of the two catalogs
    sm.get_catalogs_accessible_by_user.return_value = {"db2"}

    response = client.get("/api/v1/database/1/catalogs/")
    assert response.status_code == 200
    assert response.json == {"result": ["db2"]}
    db_mock.get_all_catalog_names.assert_called_with(
        cache=db_mock.catalog_cache_enabled,
        cache_timeout=db_mock.catalog_cache_timeout,
        force=False,
    )
    sm.get_catalogs_accessible_by_user.assert_called_with(
        db_mock,
        {"db1", "db2"},
    )

    # `force:!t` in the rison payload bypasses the cache
    response = client.get("/api/v1/database/1/catalogs/?q=(force:!t)")
    db_mock.get_all_catalog_names.assert_called_with(
        cache=db_mock.catalog_cache_enabled,
        cache_timeout=db_mock.catalog_cache_timeout,
        force=True,
    )
def test_schemas(
    mocker: MockFixture,
    client: Any,
    full_api_access: None,
) -> None:
    """
    Test the `schemas` endpoint.
    """
    from superset.databases.api import DatabaseRestApi

    db_mock = mocker.MagicMock()
    db_mock.get_all_schema_names.return_value = {"schema1", "schema2"}
    datamodel = mocker.patch.object(DatabaseRestApi, "datamodel")
    datamodel.get.return_value = db_mock

    sm = mocker.patch(
        "superset.databases.api.security_manager",
        new=mocker.MagicMock(),
    )
    # the user can only see one of the two schemas
    sm.get_schemas_accessible_by_user.return_value = {"schema2"}

    # plain request: no catalog, cache allowed
    response = client.get("/api/v1/database/1/schemas/")
    assert response.status_code == 200
    assert response.json == {"result": ["schema2"]}
    db_mock.get_all_schema_names.assert_called_with(
        catalog=None,
        cache=db_mock.schema_cache_enabled,
        cache_timeout=db_mock.schema_cache_timeout,
        force=False,
    )
    sm.get_schemas_accessible_by_user.assert_called_with(
        db_mock,
        None,
        {"schema1", "schema2"},
    )

    # `force:!t` bypasses the cache
    response = client.get("/api/v1/database/1/schemas/?q=(force:!t)")
    db_mock.get_all_schema_names.assert_called_with(
        catalog=None,
        cache=db_mock.schema_cache_enabled,
        cache_timeout=db_mock.schema_cache_timeout,
        force=True,
    )

    # a specific catalog can be requested
    response = client.get("/api/v1/database/1/schemas/?q=(force:!t,catalog:catalog2)")
    db_mock.get_all_schema_names.assert_called_with(
        catalog="catalog2",
        cache=db_mock.schema_cache_enabled,
        cache_timeout=db_mock.schema_cache_timeout,
        force=True,
    )
    sm.get_schemas_accessible_by_user.assert_called_with(
        db_mock,
        "catalog2",
        {"schema1", "schema2"},
    )

View File

@ -0,0 +1,128 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from flask_appbuilder.models.sqla.interface import SQLAInterface
from pytest_mock import MockerFixture
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from superset.databases.filters import can_access_databases, DatabaseFilter
from superset.extensions import security_manager
def test_can_access_databases(mocker: MockerFixture) -> None:
    """
    Test the `can_access_databases` function.
    """
    # one set per `user_view_menu_names` call, consumed in the order the
    # assertions below trigger them
    view_menu_names = [
        {
            "[my_db].[examples].[public].[table1](id:1)",
            "[my_other_db].[examples].[public].[table1](id:2)",
        },
        {"[my_db].(id:42)", "[my_other_db].(id:43)"},
        {"[my_db].[examples]", "[my_db].[other]"},
        {
            "[my_db].[examples].[information_schema]",
            "[my_db].[other].[secret]",
            "[third_db].[schema]",
        },
    ]
    mocker.patch.object(
        security_manager,
        "user_view_menu_names",
        side_effect=view_menu_names,
    )

    assert can_access_databases("datasource_access") == {"my_db", "my_other_db"}
    assert can_access_databases("database_access") == {"my_db", "my_other_db"}
    assert can_access_databases("catalog_access") == {"my_db"}
    assert can_access_databases("schema_access") == {"my_db", "third_db"}
def test_database_filter_full_db_access(mocker: MockerFixture) -> None:
    """
    Test the `DatabaseFilter` class when the user has full database access.

    In this case the query should be returned unmodified.
    """
    from superset.models.core import Database

    app_mock = mocker.patch("superset.databases.filters.current_app")
    app_mock.config = {"EXTRA_DYNAMIC_QUERY_FILTERS": False}
    mocker.patch.object(security_manager, "can_access_all_databases", return_value=True)

    session = sessionmaker(bind=create_engine("sqlite://"))()
    base_query = session.query(Database)
    db_filter = DatabaseFilter("id", SQLAInterface(Database))

    # full access means the filter is a no-op
    assert db_filter.apply(base_query, None) == base_query
def test_database_filter(mocker: MockerFixture) -> None:
    """
    Test the `DatabaseFilter` class with specific permissions.
    """
    from superset.models.core import Database

    current_app = mocker.patch("superset.databases.filters.current_app")
    current_app.config = {"EXTRA_DYNAMIC_QUERY_FILTERS": False}
    mocker.patch.object(
        security_manager,
        "can_access_all_databases",
        return_value=False,
    )
    # NOTE(review): one entry per `user_view_menu_names` call made by the
    # filter; the order below must match the filter's lookup sequence —
    # confirm against DatabaseFilter.apply before reordering.
    mocker.patch.object(
        security_manager,
        "user_view_menu_names",
        side_effect=[
            # return lists instead of sets to ensure order
            ["[my_db].(id:42)", "[my_other_db].(id:43)"],
            ["[my_db].[examples]", "[my_db].[other]"],
            [
                "[my_db].[examples].[information_schema]",
                "[my_db].[other].[secret]",
                "[third_db].[schema]",
            ],
            [
                "[my_db].[examples].[public].[table1](id:1)",
                "[my_other_db].[examples].[public].[table1](id:2)",
            ],
        ],
    )

    engine = create_engine("sqlite://")
    Session = sessionmaker(bind=engine)
    session = Session()
    query = session.query(Database)
    filter_ = DatabaseFilter("id", SQLAInterface(Database))
    filtered_query = filter_.apply(query, None)
    # render the filtered query with literal values so the generated SQL
    # can be compared as a string
    compiled_query = filtered_query.statement.compile(
        engine,
        compile_kwargs={"literal_binds": True},
    )
    space = " "  # pre-commit removes trailing spaces...
    assert (
        str(compiled_query)
        == f"""SELECT dbs.uuid, dbs.created_on, dbs.changed_on, dbs.id, dbs.verbose_name, dbs.database_name, dbs.sqlalchemy_uri, dbs.password, dbs.cache_timeout, dbs.select_as_create_table_as, dbs.expose_in_sqllab, dbs.configuration_method, dbs.allow_run_async, dbs.allow_file_upload, dbs.allow_ctas, dbs.allow_cvas, dbs.allow_dml, dbs.force_ctas_schema, dbs.extra, dbs.encrypted_extra, dbs.impersonate_user, dbs.server_cert, dbs.is_managed_externally, dbs.external_url, dbs.created_by_fk, dbs.changed_by_fk{space}
FROM dbs{space}
WHERE '[' || dbs.database_name || '].(id:' || CAST(dbs.id AS VARCHAR) || ')' IN ('[my_db].(id:42)', '[my_other_db].(id:43)') OR dbs.database_name IN ('my_db', 'my_other_db', 'third_db')"""
    )

View File

@ -301,3 +301,13 @@ def test_extra_table_metadata(mocker: MockFixture) -> None:
)
warnings.warn.assert_called()
def test_get_default_catalog(mocker: MockFixture) -> None:
    """
    Test the `get_default_catalog` method.

    The base engine spec returns `None` as its default catalog.
    """
    from superset.db_engine_specs.base import BaseEngineSpec

    db_mock = mocker.MagicMock()
    assert BaseEngineSpec.get_default_catalog(db_mock) is None

View File

@ -175,3 +175,33 @@ SELECT * FROM some_table;
str(excinfo.value)
== "Users are not allowed to set a search path for security reasons."
)
def test_adjust_engine_params() -> None:
    """
    Test `adjust_engine_params`.

    The method can be used to adjust the catalog (database) dynamically;
    for Postgres this rewrites the database part of the SQLAlchemy URI.
    """
    from superset.db_engine_specs.postgres import PostgresEngineSpec

    url, connect_args = PostgresEngineSpec.adjust_engine_params(
        make_url("postgresql://user:password@host:5432/dev"),
        {},
        catalog="prod",
    )
    assert url == make_url("postgresql://user:password@host:5432/prod")
    assert connect_args == {}
def test_get_default_catalog() -> None:
    """
    Test `get_default_catalog`.

    For Postgres the default catalog is the database in the SQLAlchemy URI.
    """
    from superset.db_engine_specs.postgres import PostgresEngineSpec
    from superset.models.core import Database

    db = Database(
        database_name="postgres",
        sqlalchemy_uri="postgresql://user:password@host:5432/dev",
    )

    assert PostgresEngineSpec.get_default_catalog(db) == "dev"

View File

@ -272,6 +272,7 @@ def test_query_no_access(mocker: MockFixture, client) -> None:
from superset.models.sql_lab import Query
database = mocker.MagicMock()
database.get_default_catalog.return_value = None
database.get_default_schema_for_query.return_value = "public"
mocker.patch(
query_find_by_id,

View File

@ -229,3 +229,63 @@ def test_get_prequeries(mocker: MockFixture) -> None:
conn.cursor().execute.assert_has_calls(
[mocker.call("set a=1"), mocker.call("set b=2")]
)
def test_catalog_cache() -> None:
    """
    Test the catalog cache.

    The settings are read from the `metadata_cache_timeout` entry in the
    database extra payload.
    """
    extra = {"metadata_cache_timeout": {"catalog_cache_timeout": 10}}
    db = Database(
        database_name="db",
        sqlalchemy_uri="sqlite://",
        extra=json.dumps(extra),
    )

    assert db.catalog_cache_enabled
    assert db.catalog_cache_timeout == 10
def test_get_default_catalog() -> None:
    """
    Test the `get_default_catalog` method.

    For a Postgres URI the default catalog is the database name in the URI.
    """
    db = Database(
        database_name="db",
        sqlalchemy_uri="postgresql://user:password@host:5432/examples",
    )

    assert db.get_default_catalog() == "examples"
def test_get_default_schema(mocker: MockFixture) -> None:
    """
    Test the `get_default_schema` method.
    """
    database = Database(
        database_name="db",
        sqlalchemy_uri="postgresql://user:password@host:5432/examples",
    )
    get_inspector = mocker.patch.object(database, "get_inspector")
    # `get_inspector` is used as a context manager; entering the mocked
    # manager yields the same mock object every time, so configuring it
    # here sets up the inspector the method under test will see
    with get_inspector() as inspector:
        inspector.default_schema_name = "public"

    assert database.get_default_schema("examples") == "public"
    # the inspector must be created for the requested catalog
    get_inspector.assert_called_with(catalog="examples")
def test_get_all_catalog_names(mocker: MockFixture) -> None:
    """
    Test the `get_all_catalog_names` method.
    """
    database = Database(
        database_name="db",
        sqlalchemy_uri="postgresql://user:password@host:5432/examples",
    )
    get_inspector = mocker.patch.object(database, "get_inspector")
    # `get_inspector` is used as a context manager; entering the mocked
    # manager yields the same mock object every time, so configuring it
    # here sets up the inspector the method under test will see.  The
    # catalog query result is a list of one-element row tuples.
    with get_inspector() as inspector:
        inspector.bind.execute.return_value = [("examples",), ("other",)]

    assert database.get_all_catalog_names(force=True) == {"examples", "other"}
    get_inspector.assert_called_with(ssh_tunnel=None)

View File

@ -26,7 +26,10 @@ from superset.connectors.sqla.models import Database, SqlaTable
from superset.exceptions import SupersetSecurityException
from superset.extensions import appbuilder
from superset.models.slice import Slice
from superset.security.manager import query_context_modified, SupersetSecurityManager
from superset.security.manager import (
query_context_modified,
SupersetSecurityManager,
)
from superset.sql_parse import Table
from superset.superset_typing import AdhocMetric
from superset.utils.core import override_user
@ -208,6 +211,7 @@ def test_raise_for_access_query_default_schema(
SqlaTable.query_datasources_by_name.return_value = []
database = mocker.MagicMock()
database.get_default_catalog.return_value = None
database.get_default_schema_for_query.return_value = "public"
query = mocker.MagicMock()
query.database = database
@ -262,6 +266,7 @@ def test_raise_for_access_jinja_sql(mocker: MockFixture, app_context: None) -> N
SqlaTable.query_datasources_by_name.return_value = []
database = mocker.MagicMock()
database.get_default_catalog.return_value = None
database.get_default_schema_for_query.return_value = "public"
query = mocker.MagicMock()
query.database = database
@ -531,3 +536,28 @@ def test_query_context_modified_mixed_chart(mocker: MockFixture) -> None:
}
query_context.queries = [QueryObject(metrics=requested_metrics)] # type: ignore
assert not query_context_modified(query_context)
def test_get_catalog_perm() -> None:
    """
    Test the `get_catalog_perm` method.
    """
    manager = SupersetSecurityManager(appbuilder)

    # no catalog: no permission to build
    assert manager.get_catalog_perm("my_db", None) is None
    assert manager.get_catalog_perm("my_db", "my_catalog") == "[my_db].[my_catalog]"
def test_get_schema_perm() -> None:
    """
    Test the `get_schema_perm` method.
    """
    manager = SupersetSecurityManager(appbuilder)

    # without a catalog the perm has two parts; with one, three
    assert manager.get_schema_perm("my_db", None, "my_schema") == "[my_db].[my_schema]"
    assert (
        manager.get_schema_perm("my_db", "my_catalog", "my_schema")
        == "[my_db].[my_catalog].[my_schema]"
    )

    # no schema means no permission, regardless of catalog
    assert manager.get_schema_perm("my_db", None, None) is None
    assert manager.get_schema_perm("my_db", "my_catalog", None) is None

View File

@ -0,0 +1,54 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from pytest_mock import MockerFixture
from sqlalchemy import create_engine
from superset.utils.filters import get_dataset_access_filters
def test_get_dataset_access_filters(mocker: MockerFixture) -> None:
    """
    Test the `get_dataset_access_filters` function.
    """
    from superset.connectors.sqla.models import SqlaTable
    from superset.extensions import security_manager

    mocker.patch.object(
        security_manager,
        "get_accessible_databases",
        return_value=[1, 3],
    )
    # one set per `user_view_menu_names` call, in the order the function
    # makes them
    mocker.patch.object(
        security_manager,
        "user_view_menu_names",
        side_effect=[
            {"[db].[catalog1].[schema1].[table1](id:1)"},
            {"[db].[catalog1].[schema2]"},
            {"[db].[catalog2]"},
        ],
    )

    sqlite_engine = create_engine("sqlite://")
    compiled = get_dataset_access_filters(SqlaTable).compile(
        sqlite_engine,
        compile_kwargs={"literal_binds": True},
    )
    assert str(compiled) == (
        "dbs.id IN (1, 3) "
        "OR tables.perm IN ('[db].[catalog1].[schema1].[table1](id:1)') "
        "OR tables.catalog_perm IN ('[db].[catalog2]') OR "
        "tables.schema_perm IN ('[db].[catalog1].[schema2]')"
    )

View File

@ -29,6 +29,7 @@ from superset.utils.core import (
DateColumn,
generic_find_constraint_name,
generic_find_fk_constraint_name,
get_datasource_full_name,
is_test,
normalize_dttm_col,
parse_boolean_string,
@ -369,3 +370,29 @@ def test_generic_find_fk_constraint_none_exist():
)
assert result is None
def test_get_datasource_full_name():
    """
    Test the `get_datasource_full_name` function.

    This is used to build permissions, so it doesn't really return the datasource full
    name. Instead, it returns a fully qualified table name that includes the database
    name and schema, with each part wrapped in square brackets; `None` parts are
    simply omitted.
    """
    cases = [
        (("db", "table", "catalog", "schema"), "[db].[catalog].[schema].[table]"),
        (("db", "table", None, None), "[db].[table]"),
        (("db", "table", None, "schema"), "[db].[schema].[table]"),
        (("db", "table", "catalog", None), "[db].[catalog].[table]"),
    ]
    for args, expected in cases:
        assert get_datasource_full_name(*args) == expected

View File

@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@ -0,0 +1,65 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from pytest_mock import MockerFixture
from superset.views.database.mixins import DatabaseMixin
def test_pre_add_update_with_catalog(mocker: MockerFixture) -> None:
    """
    Test the `_pre_add_update` method on a DB with catalog support.
    """
    from superset.models.core import Database

    add_pvm = mocker.patch(
        "superset.views.database.mixins.security_manager.add_permission_view_menu"
    )

    database = Database(
        id=42,
        database_name="my_db",
        sqlalchemy_uri="postgresql://user:password@host:5432/examples",
    )
    mocker.patch.object(
        database,
        "get_all_catalog_names",
        return_value=["examples", "other"],
    )
    # one schema list per catalog, consumed in catalog order
    mocker.patch.object(
        database,
        "get_all_schema_names",
        side_effect=[
            ["public", "information_schema"],
            ["secret"],
        ],
    )

    DatabaseMixin()._pre_add_update(database)

    # database, catalog, and schema permissions are all created
    add_pvm.assert_has_calls(
        [
            mocker.call("database_access", "[my_db].(id:42)"),
            mocker.call("catalog_access", "[my_db].[examples]"),
            mocker.call("catalog_access", "[my_db].[other]"),
            mocker.call("schema_access", "[my_db].[examples].[public]"),
            mocker.call("schema_access", "[my_db].[examples].[information_schema]"),
            mocker.call("schema_access", "[my_db].[other].[secret]"),
        ],
        any_order=True,
    )