feat(SIP-95): permissions for catalogs (#28317)
This commit is contained in:
parent
9a339f08a7
commit
e90246fd1f
|
|
@ -97,12 +97,37 @@ class CreateDatabaseCommand(BaseCommand):
|
|||
|
||||
db.session.commit()
|
||||
|
||||
# adding a new database we always want to force refresh schema list
|
||||
schemas = database.get_all_schema_names(cache=False, ssh_tunnel=ssh_tunnel)
|
||||
for schema in schemas:
|
||||
security_manager.add_permission_view_menu(
|
||||
"schema_access", security_manager.get_schema_perm(database, schema)
|
||||
# add catalog/schema permissions
|
||||
if database.db_engine_spec.supports_catalog:
|
||||
catalogs = database.get_all_catalog_names(
|
||||
cache=False,
|
||||
ssh_tunnel=ssh_tunnel,
|
||||
)
|
||||
for catalog in catalogs:
|
||||
security_manager.add_permission_view_menu(
|
||||
"catalog_access",
|
||||
security_manager.get_catalog_perm(
|
||||
database.database_name, catalog
|
||||
),
|
||||
)
|
||||
else:
|
||||
# add a dummy catalog for DBs that don't support them
|
||||
catalogs = [None]
|
||||
|
||||
for catalog in catalogs:
|
||||
for schema in database.get_all_schema_names(
|
||||
catalog=catalog,
|
||||
cache=False,
|
||||
ssh_tunnel=ssh_tunnel,
|
||||
):
|
||||
security_manager.add_permission_view_menu(
|
||||
"schema_access",
|
||||
security_manager.get_schema_perm(
|
||||
database.database_name,
|
||||
catalog,
|
||||
schema,
|
||||
),
|
||||
)
|
||||
|
||||
except (
|
||||
SSHTunnelInvalidError,
|
||||
|
|
|
|||
|
|
@ -14,6 +14,9 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, cast
|
||||
|
||||
|
|
@ -29,7 +32,6 @@ from superset.daos.database import DatabaseDAO
|
|||
from superset.exceptions import SupersetException
|
||||
from superset.extensions import db, security_manager
|
||||
from superset.models.core import Database
|
||||
from superset.utils.core import DatasourceName
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -37,8 +39,15 @@ logger = logging.getLogger(__name__)
|
|||
class TablesDatabaseCommand(BaseCommand):
|
||||
_model: Database
|
||||
|
||||
def __init__(self, db_id: int, schema_name: str, force: bool):
|
||||
def __init__(
|
||||
self,
|
||||
db_id: int,
|
||||
catalog_name: str | None,
|
||||
schema_name: str,
|
||||
force: bool,
|
||||
):
|
||||
self._db_id = db_id
|
||||
self._catalog_name = catalog_name
|
||||
self._schema_name = schema_name
|
||||
self._force = force
|
||||
|
||||
|
|
@ -47,11 +56,11 @@ class TablesDatabaseCommand(BaseCommand):
|
|||
try:
|
||||
tables = security_manager.get_datasources_accessible_by_user(
|
||||
database=self._model,
|
||||
catalog=self._catalog_name,
|
||||
schema=self._schema_name,
|
||||
datasource_names=sorted(
|
||||
DatasourceName(*datasource_name)
|
||||
for datasource_name in self._model.get_all_table_names_in_schema(
|
||||
catalog=None,
|
||||
self._model.get_all_table_names_in_schema(
|
||||
catalog=self._catalog_name,
|
||||
schema=self._schema_name,
|
||||
force=self._force,
|
||||
cache=self._model.table_cache_enabled,
|
||||
|
|
@ -62,11 +71,11 @@ class TablesDatabaseCommand(BaseCommand):
|
|||
|
||||
views = security_manager.get_datasources_accessible_by_user(
|
||||
database=self._model,
|
||||
catalog=self._catalog_name,
|
||||
schema=self._schema_name,
|
||||
datasource_names=sorted(
|
||||
DatasourceName(*datasource_name)
|
||||
for datasource_name in self._model.get_all_view_names_in_schema(
|
||||
catalog=None,
|
||||
self._model.get_all_view_names_in_schema(
|
||||
catalog=self._catalog_name,
|
||||
schema=self._schema_name,
|
||||
force=self._force,
|
||||
cache=self._model.table_cache_enabled,
|
||||
|
|
@ -81,11 +90,15 @@ class TablesDatabaseCommand(BaseCommand):
|
|||
db.session.query(SqlaTable)
|
||||
.filter(
|
||||
SqlaTable.database_id == self._model.id,
|
||||
SqlaTable.catalog == self._catalog_name,
|
||||
SqlaTable.schema == self._schema_name,
|
||||
)
|
||||
.options(
|
||||
load_only(
|
||||
SqlaTable.schema, SqlaTable.table_name, SqlaTable.extra
|
||||
SqlaTable.catalog,
|
||||
SqlaTable.schema,
|
||||
SqlaTable.table_name,
|
||||
SqlaTable.extra,
|
||||
),
|
||||
lazyload(SqlaTable.columns),
|
||||
lazyload(SqlaTable.metrics),
|
||||
|
|
|
|||
|
|
@ -18,10 +18,9 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from flask_appbuilder.models.sqla import Model
|
||||
from marshmallow import ValidationError
|
||||
|
||||
from superset import is_feature_enabled, security_manager
|
||||
from superset.commands.base import BaseCommand
|
||||
|
|
@ -50,12 +49,12 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class UpdateDatabaseCommand(BaseCommand):
|
||||
_model: Optional[Database]
|
||||
_model: Database | None
|
||||
|
||||
def __init__(self, model_id: int, data: dict[str, Any]):
|
||||
self._properties = data.copy()
|
||||
self._model_id = model_id
|
||||
self._model: Optional[Database] = None
|
||||
self._model: Database | None = None
|
||||
|
||||
def run(self) -> Model:
|
||||
self._model = DatabaseDAO.find_by_id(self._model_id)
|
||||
|
|
@ -85,7 +84,7 @@ class UpdateDatabaseCommand(BaseCommand):
|
|||
)
|
||||
database.set_sqlalchemy_uri(database.sqlalchemy_uri)
|
||||
ssh_tunnel = self._handle_ssh_tunnel(database)
|
||||
self._refresh_schemas(database, original_database_name, ssh_tunnel)
|
||||
self._refresh_catalogs(database, original_database_name, ssh_tunnel)
|
||||
except SSHTunnelError as ex:
|
||||
# allow exception to bubble for debugbing information
|
||||
raise ex
|
||||
|
|
@ -121,67 +120,200 @@ class UpdateDatabaseCommand(BaseCommand):
|
|||
ssh_tunnel_properties,
|
||||
).run()
|
||||
|
||||
def _refresh_schemas(
|
||||
def _get_catalog_names(
|
||||
self,
|
||||
database: Database,
|
||||
original_database_name: str,
|
||||
ssh_tunnel: Optional[SSHTunnel],
|
||||
) -> None:
|
||||
ssh_tunnel: SSHTunnel | None,
|
||||
) -> set[str]:
|
||||
"""
|
||||
Add permissions for any new schemas.
|
||||
Helper method to load catalogs.
|
||||
|
||||
This method captures a generic exception, since errors could potentially come
|
||||
from any of the 50+ database drivers we support.
|
||||
"""
|
||||
try:
|
||||
schemas = database.get_all_schema_names(ssh_tunnel=ssh_tunnel)
|
||||
return database.get_all_catalog_names(
|
||||
force=True,
|
||||
ssh_tunnel=ssh_tunnel,
|
||||
)
|
||||
except Exception as ex:
|
||||
db.session.rollback()
|
||||
raise DatabaseConnectionFailedError() from ex
|
||||
|
||||
for schema in schemas:
|
||||
original_vm = security_manager.get_schema_perm(
|
||||
def _get_schema_names(
|
||||
self,
|
||||
database: Database,
|
||||
catalog: str | None,
|
||||
ssh_tunnel: SSHTunnel | None,
|
||||
) -> set[str]:
|
||||
"""
|
||||
Helper method to load schemas.
|
||||
|
||||
This method captures a generic exception, since errors could potentially come
|
||||
from any of the 50+ database drivers we support.
|
||||
"""
|
||||
try:
|
||||
return database.get_all_schema_names(
|
||||
force=True,
|
||||
catalog=catalog,
|
||||
ssh_tunnel=ssh_tunnel,
|
||||
)
|
||||
except Exception as ex:
|
||||
db.session.rollback()
|
||||
raise DatabaseConnectionFailedError() from ex
|
||||
|
||||
def _refresh_catalogs(
|
||||
self,
|
||||
database: Database,
|
||||
original_database_name: str,
|
||||
ssh_tunnel: SSHTunnel | None,
|
||||
) -> None:
|
||||
"""
|
||||
Add permissions for any new catalogs and schemas.
|
||||
"""
|
||||
catalogs = (
|
||||
self._get_catalog_names(database, ssh_tunnel)
|
||||
if database.db_engine_spec.supports_catalog
|
||||
else [None]
|
||||
)
|
||||
|
||||
for catalog in catalogs:
|
||||
schemas = self._get_schema_names(database, catalog, ssh_tunnel)
|
||||
|
||||
if catalog:
|
||||
perm = security_manager.get_catalog_perm(
|
||||
original_database_name,
|
||||
catalog,
|
||||
)
|
||||
existing_pvm = security_manager.find_permission_view_menu(
|
||||
"catalog_access",
|
||||
perm,
|
||||
)
|
||||
if not existing_pvm:
|
||||
# new catalog
|
||||
security_manager.add_permission_view_menu(
|
||||
"catalog_access",
|
||||
security_manager.get_catalog_perm(
|
||||
database.database_name,
|
||||
catalog,
|
||||
),
|
||||
)
|
||||
for schema in schemas:
|
||||
security_manager.add_permission_view_menu(
|
||||
"schema_access",
|
||||
security_manager.get_schema_perm(
|
||||
database.database_name,
|
||||
catalog,
|
||||
schema,
|
||||
),
|
||||
)
|
||||
continue
|
||||
|
||||
# add possible new schemas in catalog
|
||||
self._refresh_schemas(
|
||||
database,
|
||||
original_database_name,
|
||||
catalog,
|
||||
schemas,
|
||||
)
|
||||
|
||||
if original_database_name != database.database_name:
|
||||
self._rename_database_in_permissions(
|
||||
database,
|
||||
original_database_name,
|
||||
catalog,
|
||||
schemas,
|
||||
)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
def _refresh_schemas(
|
||||
self,
|
||||
database: Database,
|
||||
original_database_name: str,
|
||||
catalog: str | None,
|
||||
schemas: set[str],
|
||||
) -> None:
|
||||
"""
|
||||
Add new schemas that don't have permissions yet.
|
||||
"""
|
||||
for schema in schemas:
|
||||
perm = security_manager.get_schema_perm(
|
||||
original_database_name,
|
||||
catalog,
|
||||
schema,
|
||||
)
|
||||
existing_pvm = security_manager.find_permission_view_menu(
|
||||
"schema_access",
|
||||
original_vm,
|
||||
perm,
|
||||
)
|
||||
if not existing_pvm:
|
||||
# new schema
|
||||
security_manager.add_permission_view_menu(
|
||||
"schema_access",
|
||||
security_manager.get_schema_perm(database.database_name, schema),
|
||||
new_name = security_manager.get_schema_perm(
|
||||
database.database_name,
|
||||
catalog,
|
||||
schema,
|
||||
)
|
||||
continue
|
||||
security_manager.add_permission_view_menu("schema_access", new_name)
|
||||
|
||||
if original_database_name == database.database_name:
|
||||
continue
|
||||
def _rename_database_in_permissions(
|
||||
self,
|
||||
database: Database,
|
||||
original_database_name: str,
|
||||
catalog: str | None,
|
||||
schemas: set[str],
|
||||
) -> None:
|
||||
new_name = security_manager.get_catalog_perm(
|
||||
database.database_name,
|
||||
catalog,
|
||||
)
|
||||
|
||||
# rename existing schema permission
|
||||
existing_pvm.view_menu.name = security_manager.get_schema_perm(
|
||||
# rename existing catalog permission
|
||||
if catalog:
|
||||
perm = security_manager.get_catalog_perm(
|
||||
original_database_name,
|
||||
catalog,
|
||||
)
|
||||
existing_pvm = security_manager.find_permission_view_menu(
|
||||
"catalog_access",
|
||||
perm,
|
||||
)
|
||||
if existing_pvm:
|
||||
existing_pvm.view_menu.name = new_name
|
||||
|
||||
for schema in schemas:
|
||||
new_name = security_manager.get_schema_perm(
|
||||
database.database_name,
|
||||
catalog,
|
||||
schema,
|
||||
)
|
||||
|
||||
# rename existing schema permission
|
||||
perm = security_manager.get_schema_perm(
|
||||
original_database_name,
|
||||
catalog,
|
||||
schema,
|
||||
)
|
||||
existing_pvm = security_manager.find_permission_view_menu(
|
||||
"schema_access",
|
||||
perm,
|
||||
)
|
||||
if existing_pvm:
|
||||
existing_pvm.view_menu.name = new_name
|
||||
|
||||
# rename permissions on datasets and charts
|
||||
for dataset in DatabaseDAO.get_datasets(
|
||||
database.id,
|
||||
catalog=None,
|
||||
catalog=catalog,
|
||||
schema=schema,
|
||||
):
|
||||
dataset.schema_perm = existing_pvm.view_menu.name
|
||||
dataset.schema_perm = new_name
|
||||
for chart in DatasetDAO.get_related_objects(dataset.id)["charts"]:
|
||||
chart.schema_perm = existing_pvm.view_menu.name
|
||||
|
||||
db.session.commit()
|
||||
chart.schema_perm = new_name
|
||||
|
||||
def validate(self) -> None:
|
||||
exceptions: list[ValidationError] = []
|
||||
database_name: Optional[str] = self._properties.get("database_name")
|
||||
if database_name:
|
||||
# Check database_name uniqueness
|
||||
if database_name := self._properties.get("database_name"):
|
||||
if not DatabaseDAO.validate_update_uniqueness(
|
||||
self._model_id, database_name
|
||||
self._model_id,
|
||||
database_name,
|
||||
):
|
||||
exceptions.append(DatabaseExistsValidationError())
|
||||
if exceptions:
|
||||
raise DatabaseInvalidError(exceptions=exceptions)
|
||||
raise DatabaseInvalidError(exceptions=[DatabaseExistsValidationError()])
|
||||
|
|
|
|||
|
|
@ -126,7 +126,11 @@ class SqlResultExportCommand(BaseCommand):
|
|||
}:
|
||||
# remove extra row from `increased_limit`
|
||||
limit -= 1
|
||||
df = self._query.database.get_df(sql, self._query.schema)[:limit]
|
||||
df = self._query.database.get_df(
|
||||
sql,
|
||||
self._query.catalog,
|
||||
self._query.schema,
|
||||
)[:limit]
|
||||
|
||||
csv_data = csv.df_to_escaped_csv(df, index=False, **config["CSV_EXPORT"])
|
||||
|
||||
|
|
|
|||
|
|
@ -564,9 +564,9 @@ IS_FEATURE_ENABLED_FUNC: Callable[[str, bool | None], bool] | None = None
|
|||
#
|
||||
# Takes as a parameter the common bootstrap payload before transformations.
|
||||
# Returns a dict containing data that should be added or overridden to the payload.
|
||||
COMMON_BOOTSTRAP_OVERRIDES_FUNC: Callable[[dict[str, Any]], dict[str, Any]] = ( # noqa: E731
|
||||
lambda data: {}
|
||||
) # default: empty dict
|
||||
COMMON_BOOTSTRAP_OVERRIDES_FUNC: Callable[ # noqa: E731
|
||||
[dict[str, Any]], dict[str, Any]
|
||||
] = lambda data: {}
|
||||
|
||||
# EXTRA_CATEGORICAL_COLOR_SCHEMES is used for adding custom categorical color schemes
|
||||
# example code for "My custom warm to hot" color scheme
|
||||
|
|
|
|||
|
|
@ -211,6 +211,7 @@ class BaseDatasource(AuditMixinNullable, ImportExportMixin): # pylint: disable=
|
|||
params = Column(String(1000))
|
||||
perm = Column(String(1000))
|
||||
schema_perm = Column(String(1000))
|
||||
catalog_perm = Column(String(1000), nullable=True, default=None)
|
||||
is_managed_externally = Column(Boolean, nullable=False, default=False)
|
||||
external_url = Column(Text, nullable=True)
|
||||
|
||||
|
|
@ -1261,9 +1262,20 @@ class SqlaTable(
|
|||
anchor = f'<a target="_blank" href="{self.explore_url}">{name}</a>'
|
||||
return Markup(anchor)
|
||||
|
||||
def get_catalog_perm(self) -> str | None:
|
||||
"""Returns catalog permission if present, database one otherwise."""
|
||||
return security_manager.get_catalog_perm(
|
||||
self.database.database_name,
|
||||
self.catalog,
|
||||
)
|
||||
|
||||
def get_schema_perm(self) -> str | None:
|
||||
"""Returns schema permission if present, database one otherwise."""
|
||||
return security_manager.get_schema_perm(self.database, self.schema or None)
|
||||
return security_manager.get_schema_perm(
|
||||
self.database.database_name,
|
||||
self.catalog,
|
||||
self.schema or None,
|
||||
)
|
||||
|
||||
def get_perm(self) -> str:
|
||||
"""
|
||||
|
|
@ -1282,7 +1294,10 @@ class SqlaTable(
|
|||
@property
|
||||
def full_name(self) -> str:
|
||||
return utils.get_datasource_full_name(
|
||||
self.database, self.table_name, schema=self.schema
|
||||
self.database,
|
||||
self.table_name,
|
||||
catalog=self.catalog,
|
||||
schema=self.schema,
|
||||
)
|
||||
|
||||
@property
|
||||
|
|
@ -1736,7 +1751,10 @@ class SqlaTable(
|
|||
|
||||
try:
|
||||
df = self.database.get_df(
|
||||
sql, self.schema or None, mutator=assign_column_label
|
||||
sql,
|
||||
None,
|
||||
self.schema or None,
|
||||
mutator=assign_column_label,
|
||||
)
|
||||
except (SupersetErrorException, SupersetErrorsException) as ex:
|
||||
# SupersetError(s) exception should not be captured; instead, they should
|
||||
|
|
@ -1870,34 +1888,45 @@ class SqlaTable(
|
|||
cls,
|
||||
database: Database,
|
||||
datasource_name: str,
|
||||
catalog: str | None = None,
|
||||
schema: str | None = None,
|
||||
) -> list[SqlaTable]:
|
||||
query = (
|
||||
db.session.query(cls)
|
||||
.filter_by(database_id=database.id)
|
||||
.filter_by(table_name=datasource_name)
|
||||
)
|
||||
filters = {
|
||||
"database_id": database.id,
|
||||
"table_name": datasource_name,
|
||||
}
|
||||
if catalog:
|
||||
filters["catalog"] = catalog
|
||||
if schema:
|
||||
query = query.filter_by(schema=schema)
|
||||
return query.all()
|
||||
filters["schema"] = schema
|
||||
|
||||
return db.session.query(cls).filter_by(**filters).all()
|
||||
|
||||
@classmethod
|
||||
def query_datasources_by_permissions( # pylint: disable=invalid-name
|
||||
cls,
|
||||
database: Database,
|
||||
permissions: set[str],
|
||||
catalog_perms: set[str],
|
||||
schema_perms: set[str],
|
||||
) -> list[SqlaTable]:
|
||||
# TODO(hughhhh): add unit test
|
||||
# remove empty sets from the query, since SQLAlchemy produces horrible SQL for
|
||||
# Model.column._in({}):
|
||||
#
|
||||
# table.column IN (SELECT 1 FROM (SELECT 1) WHERE 1!=1)
|
||||
filters = [
|
||||
method.in_(perms)
|
||||
for method, perms in zip(
|
||||
(SqlaTable.perm, SqlaTable.schema_perm, SqlaTable.catalog_perm),
|
||||
(permissions, schema_perms, catalog_perms),
|
||||
)
|
||||
if perms
|
||||
]
|
||||
|
||||
return (
|
||||
db.session.query(cls)
|
||||
.filter_by(database_id=database.id)
|
||||
.filter(
|
||||
or_(
|
||||
SqlaTable.perm.in_(permissions),
|
||||
SqlaTable.schema_perm.in_(schema_perms),
|
||||
)
|
||||
)
|
||||
.filter(or_(*filters))
|
||||
.all()
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -132,6 +132,7 @@ MODEL_API_RW_METHOD_PERMISSION_MAP = {
|
|||
"related_objects": "read",
|
||||
"tables": "read",
|
||||
"schemas": "read",
|
||||
"catalogs": "read",
|
||||
"select_star": "read",
|
||||
"table_metadata": "read",
|
||||
"table_metadata_deprecated": "read",
|
||||
|
|
|
|||
|
|
@ -73,10 +73,12 @@ from superset.daos.database import DatabaseDAO, DatabaseUserOAuth2TokensDAO
|
|||
from superset.databases.decorators import check_table_access
|
||||
from superset.databases.filters import DatabaseFilter, DatabaseUploadEnabledFilter
|
||||
from superset.databases.schemas import (
|
||||
CatalogsResponseSchema,
|
||||
ColumnarMetadataUploadFilePostSchema,
|
||||
ColumnarUploadPostSchema,
|
||||
CSVMetadataUploadFilePostSchema,
|
||||
CSVUploadPostSchema,
|
||||
database_catalogs_query_schema,
|
||||
database_schemas_query_schema,
|
||||
database_tables_query_schema,
|
||||
DatabaseConnectionSchema,
|
||||
|
|
@ -146,6 +148,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
|||
"table_extra_metadata",
|
||||
"table_extra_metadata_deprecated",
|
||||
"select_star",
|
||||
"catalogs",
|
||||
"schemas",
|
||||
"test_connection",
|
||||
"related_objects",
|
||||
|
|
@ -266,6 +269,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
|||
edit_model_schema = DatabasePutSchema()
|
||||
|
||||
apispec_parameter_schemas = {
|
||||
"database_catalogs_query_schema": database_catalogs_query_schema,
|
||||
"database_schemas_query_schema": database_schemas_query_schema,
|
||||
"database_tables_query_schema": database_tables_query_schema,
|
||||
"get_export_ids_schema": get_export_ids_schema,
|
||||
|
|
@ -273,6 +277,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
|||
|
||||
openapi_spec_tag = "Database"
|
||||
openapi_spec_component_schemas = (
|
||||
CatalogsResponseSchema,
|
||||
ColumnarUploadPostSchema,
|
||||
CSVUploadPostSchema,
|
||||
DatabaseConnectionSchema,
|
||||
|
|
@ -604,6 +609,70 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
|||
)
|
||||
return self.response_422(message=str(ex))
|
||||
|
||||
@expose("/<int:pk>/catalogs/")
|
||||
@protect()
|
||||
@safe
|
||||
@rison(database_catalogs_query_schema)
|
||||
@statsd_metrics
|
||||
@event_logger.log_this_with_context(
|
||||
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}" f".catalogs",
|
||||
log_to_statsd=False,
|
||||
)
|
||||
def catalogs(self, pk: int, **kwargs: Any) -> FlaskResponse:
|
||||
"""Get all catalogs from a database.
|
||||
---
|
||||
get:
|
||||
summary: Get all catalogs from a database
|
||||
parameters:
|
||||
- in: path
|
||||
schema:
|
||||
type: integer
|
||||
name: pk
|
||||
description: The database id
|
||||
- in: query
|
||||
name: q
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/database_catalogs_query_schema'
|
||||
responses:
|
||||
200:
|
||||
description: A List of all catalogs from the database
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/CatalogsResponseSchema"
|
||||
400:
|
||||
$ref: '#/components/responses/400'
|
||||
401:
|
||||
$ref: '#/components/responses/401'
|
||||
404:
|
||||
$ref: '#/components/responses/404'
|
||||
500:
|
||||
$ref: '#/components/responses/500'
|
||||
"""
|
||||
database = DatabaseDAO.find_by_id(pk)
|
||||
if not database:
|
||||
return self.response_404()
|
||||
try:
|
||||
catalogs = database.get_all_catalog_names(
|
||||
cache=database.catalog_cache_enabled,
|
||||
cache_timeout=database.catalog_cache_timeout or None,
|
||||
force=kwargs["rison"].get("force", False),
|
||||
)
|
||||
catalogs = security_manager.get_catalogs_accessible_by_user(
|
||||
database,
|
||||
catalogs,
|
||||
)
|
||||
return self.response(200, result=list(catalogs))
|
||||
except OperationalError:
|
||||
return self.response(
|
||||
500,
|
||||
message="There was an error connecting to the database",
|
||||
)
|
||||
except SupersetException as ex:
|
||||
return self.response(ex.status, message=ex.message)
|
||||
|
||||
@expose("/<int:pk>/schemas/")
|
||||
@protect()
|
||||
@safe
|
||||
|
|
@ -650,13 +719,19 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
|||
if not database:
|
||||
return self.response_404()
|
||||
try:
|
||||
catalog = kwargs["rison"].get("catalog")
|
||||
schemas = database.get_all_schema_names(
|
||||
catalog=catalog,
|
||||
cache=database.schema_cache_enabled,
|
||||
cache_timeout=database.schema_cache_timeout or None,
|
||||
force=kwargs["rison"].get("force", False),
|
||||
)
|
||||
schemas = security_manager.get_schemas_accessible_by_user(database, schemas)
|
||||
return self.response(200, result=schemas)
|
||||
schemas = security_manager.get_schemas_accessible_by_user(
|
||||
database,
|
||||
catalog,
|
||||
schemas,
|
||||
)
|
||||
return self.response(200, result=list(schemas))
|
||||
except OperationalError:
|
||||
return self.response(
|
||||
500, message="There was an error connecting to the database"
|
||||
|
|
@ -718,10 +793,11 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
|||
$ref: '#/components/responses/500'
|
||||
"""
|
||||
force = kwargs["rison"].get("force", False)
|
||||
catalog_name = kwargs["rison"].get("catalog_name")
|
||||
schema_name = kwargs["rison"].get("schema_name", "")
|
||||
|
||||
try:
|
||||
command = TablesDatabaseCommand(pk, schema_name, force)
|
||||
command = TablesDatabaseCommand(pk, catalog_name, schema_name, force)
|
||||
payload = command.run()
|
||||
return self.response(200, **payload)
|
||||
except DatabaseNotFoundError:
|
||||
|
|
|
|||
|
|
@ -28,11 +28,12 @@ from superset.models.core import Database
|
|||
from superset.views.base import BaseFilter
|
||||
|
||||
|
||||
def can_access_databases(
|
||||
view_menu_name: str,
|
||||
) -> set[str]:
|
||||
def can_access_databases(view_menu_name: str) -> set[str]:
|
||||
"""
|
||||
Return names of databases available in `view_menu_name`.
|
||||
"""
|
||||
return {
|
||||
security_manager.unpack_database_and_schema(vm).database
|
||||
vm.split(".")[0][1:-1]
|
||||
for vm in security_manager.user_view_menu_names(view_menu_name)
|
||||
}
|
||||
|
||||
|
|
@ -56,17 +57,21 @@ class DatabaseFilter(BaseFilter): # pylint: disable=too-few-public-methods
|
|||
# We can proceed with default filtering now
|
||||
if security_manager.can_access_all_databases():
|
||||
return query
|
||||
database_perms = security_manager.user_view_menu_names("database_access")
|
||||
schema_access_databases = can_access_databases("schema_access")
|
||||
|
||||
database_perms = security_manager.user_view_menu_names("database_access")
|
||||
catalog_access_databases = can_access_databases("catalog_access")
|
||||
schema_access_databases = can_access_databases("schema_access")
|
||||
datasource_access_databases = can_access_databases("datasource_access")
|
||||
database_names = sorted(
|
||||
catalog_access_databases
|
||||
| schema_access_databases
|
||||
| datasource_access_databases
|
||||
)
|
||||
|
||||
return query.filter(
|
||||
or_(
|
||||
self.model.perm.in_(database_perms),
|
||||
self.model.database_name.in_(
|
||||
[*schema_access_databases, *datasource_access_databases]
|
||||
),
|
||||
self.model.database_name.in_(database_names),
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -56,6 +56,14 @@ from superset.security.analytics_db_safety import check_sqlalchemy_uri
|
|||
from superset.utils.core import markdown, parse_ssl_cert
|
||||
|
||||
database_schemas_query_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"force": {"type": "boolean"},
|
||||
"catalog": {"type": "string"},
|
||||
},
|
||||
}
|
||||
|
||||
database_catalogs_query_schema = {
|
||||
"type": "object",
|
||||
"properties": {"force": {"type": "boolean"}},
|
||||
}
|
||||
|
|
@ -65,6 +73,7 @@ database_tables_query_schema = {
|
|||
"properties": {
|
||||
"force": {"type": "boolean"},
|
||||
"schema_name": {"type": "string"},
|
||||
"catalog_name": {"type": "string"},
|
||||
},
|
||||
"required": ["schema_name"],
|
||||
}
|
||||
|
|
@ -712,6 +721,12 @@ class SchemasResponseSchema(Schema):
|
|||
)
|
||||
|
||||
|
||||
class CatalogsResponseSchema(Schema):
|
||||
result = fields.List(
|
||||
fields.String(metadata={"description": "A database catalog name"})
|
||||
)
|
||||
|
||||
|
||||
class DatabaseTablesResponse(Schema):
|
||||
extra = fields.Dict(
|
||||
metadata={"description": "Extra data used to specify column metadata"}
|
||||
|
|
|
|||
|
|
@ -404,6 +404,10 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
# Does the DB support catalogs? A catalog here is a group of schemas, and has
|
||||
# different names depending on the DB: BigQuery calles it a "project", Postgres calls
|
||||
# it a "database", Trino calls it a "catalog", etc.
|
||||
#
|
||||
# When this is changed to true in a DB engine spec it MUST support the
|
||||
# `get_default_catalog` and `get_catalog_names` methods. In addition, you MUST write
|
||||
# a database migration updating any existing schema permissions.
|
||||
supports_catalog = False
|
||||
|
||||
# Can the catalog be changed on a per-query basis?
|
||||
|
|
@ -638,10 +642,20 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
|
||||
return driver in cls.drivers
|
||||
|
||||
@classmethod
|
||||
def get_default_catalog(
|
||||
cls,
|
||||
database: Database, # pylint: disable=unused-argument
|
||||
) -> str | None:
|
||||
"""
|
||||
Return the default catalog for a given database.
|
||||
"""
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_default_schema(cls, database: Database, catalog: str | None) -> str | None:
|
||||
"""
|
||||
Return the default schema in a given database.
|
||||
Return the default schema for a catalog in a given database.
|
||||
"""
|
||||
with database.get_inspector(catalog=catalog) as inspector:
|
||||
return inspector.default_schema_name
|
||||
|
|
@ -1412,24 +1426,24 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
cls,
|
||||
database: Database,
|
||||
inspector: Inspector,
|
||||
) -> list[str]:
|
||||
) -> set[str]:
|
||||
"""
|
||||
Get all catalogs from database.
|
||||
|
||||
This needs to be implemented per database, since SQLAlchemy doesn't offer an
|
||||
abstraction.
|
||||
"""
|
||||
return []
|
||||
return set()
|
||||
|
||||
@classmethod
|
||||
def get_schema_names(cls, inspector: Inspector) -> list[str]:
|
||||
def get_schema_names(cls, inspector: Inspector) -> set[str]:
|
||||
"""
|
||||
Get all schemas from database
|
||||
|
||||
:param inspector: SqlAlchemy inspector
|
||||
:return: All schemas in the database
|
||||
"""
|
||||
return sorted(inspector.get_schema_names())
|
||||
return set(inspector.get_schema_names())
|
||||
|
||||
@classmethod
|
||||
def get_table_names( # pylint: disable=unused-argument
|
||||
|
|
|
|||
|
|
@ -127,7 +127,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
|
|||
|
||||
allows_hidden_cc_in_orderby = True
|
||||
|
||||
supports_catalog = True
|
||||
supports_catalog = False
|
||||
|
||||
"""
|
||||
https://www.python.org/dev/peps/pep-0249/#arraysize
|
||||
|
|
@ -464,7 +464,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
|
|||
cls,
|
||||
database: Database,
|
||||
inspector: Inspector,
|
||||
) -> list[str]:
|
||||
) -> set[str]:
|
||||
"""
|
||||
Get all catalogs.
|
||||
|
||||
|
|
@ -475,7 +475,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
|
|||
client = cls._get_client(engine)
|
||||
projects = client.list_projects()
|
||||
|
||||
return sorted(project.project_id for project in projects)
|
||||
return {project.project_id for project in projects}
|
||||
|
||||
@classmethod
|
||||
def get_allow_cost_estimate(cls, extra: dict[str, Any]) -> bool:
|
||||
|
|
|
|||
|
|
@ -278,7 +278,7 @@ class ClickHouseConnectEngineSpec(BasicParametersMixin, ClickHouseEngineSpec):
|
|||
|
||||
@classmethod
|
||||
def get_function_names(cls, database: Database) -> list[str]:
|
||||
# pylint: disable=import-outside-toplevel,import-error
|
||||
# pylint: disable=import-outside-toplevel, import-error
|
||||
from clickhouse_connect.driver.exceptions import ClickHouseError
|
||||
|
||||
if cls._function_names:
|
||||
|
|
@ -340,7 +340,7 @@ class ClickHouseConnectEngineSpec(BasicParametersMixin, ClickHouseEngineSpec):
|
|||
def validate_parameters(
|
||||
cls, properties: BasicPropertiesType
|
||||
) -> list[SupersetError]:
|
||||
# pylint: disable=import-outside-toplevel,import-error
|
||||
# pylint: disable=import-outside-toplevel, import-error
|
||||
from clickhouse_connect.driver import default_port
|
||||
|
||||
parameters = properties.get("parameters", {})
|
||||
|
|
|
|||
|
|
@ -74,13 +74,12 @@ class ImpalaEngineSpec(BaseEngineSpec):
|
|||
return None
|
||||
|
||||
@classmethod
|
||||
def get_schema_names(cls, inspector: Inspector) -> list[str]:
|
||||
schemas = [
|
||||
def get_schema_names(cls, inspector: Inspector) -> set[str]:
|
||||
return {
|
||||
row[0]
|
||||
for row in inspector.engine.execute("SHOW SCHEMAS")
|
||||
if not row[0].startswith("_")
|
||||
]
|
||||
return schemas
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def has_implicit_cancel(cls) -> bool:
|
||||
|
|
|
|||
|
|
@ -101,8 +101,6 @@ class PostgresBaseEngineSpec(BaseEngineSpec):
|
|||
engine = ""
|
||||
engine_name = "PostgreSQL"
|
||||
|
||||
supports_catalog = True
|
||||
|
||||
_time_grain_expressions = {
|
||||
None: "{col}",
|
||||
TimeGrain.SECOND: "DATE_TRUNC('second', {col})",
|
||||
|
|
@ -199,7 +197,10 @@ class PostgresBaseEngineSpec(BaseEngineSpec):
|
|||
class PostgresEngineSpec(BasicParametersMixin, PostgresBaseEngineSpec):
|
||||
engine = "postgresql"
|
||||
engine_aliases = {"postgres"}
|
||||
|
||||
supports_dynamic_schema = True
|
||||
supports_catalog = True
|
||||
supports_dynamic_catalog = True
|
||||
|
||||
default_driver = "psycopg2"
|
||||
sqlalchemy_uri_placeholder = (
|
||||
|
|
@ -296,6 +297,29 @@ class PostgresEngineSpec(BasicParametersMixin, PostgresBaseEngineSpec):
|
|||
|
||||
return super().get_default_schema_for_query(database, query)
|
||||
|
||||
@classmethod
|
||||
def adjust_engine_params(
|
||||
cls,
|
||||
uri: URL,
|
||||
connect_args: dict[str, Any],
|
||||
catalog: str | None = None,
|
||||
schema: str | None = None,
|
||||
) -> tuple[URL, dict[str, Any]]:
|
||||
"""
|
||||
Set the catalog (database).
|
||||
"""
|
||||
if catalog:
|
||||
uri = uri.set(database=catalog)
|
||||
|
||||
return uri, connect_args
|
||||
|
||||
@classmethod
|
||||
def get_default_catalog(cls, database: Database) -> str | None:
|
||||
"""
|
||||
Return the default catalog for a given database.
|
||||
"""
|
||||
return database.url_object.database
|
||||
|
||||
@classmethod
|
||||
def get_prequeries(
|
||||
cls,
|
||||
|
|
@ -346,13 +370,13 @@ class PostgresEngineSpec(BasicParametersMixin, PostgresBaseEngineSpec):
|
|||
cls,
|
||||
database: Database,
|
||||
inspector: Inspector,
|
||||
) -> list[str]:
|
||||
) -> set[str]:
|
||||
"""
|
||||
Return all catalogs.
|
||||
|
||||
In Postgres, a catalog is called a "database".
|
||||
"""
|
||||
return sorted(
|
||||
return {
|
||||
catalog
|
||||
for (catalog,) in inspector.bind.execute(
|
||||
"""
|
||||
|
|
@ -360,7 +384,7 @@ SELECT datname FROM pg_database
|
|||
WHERE datistemplate = false;
|
||||
"""
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_table_names(
|
||||
|
|
|
|||
|
|
@ -648,6 +648,8 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
|
|||
engine_name = "Presto"
|
||||
allows_alias_to_source_column = False
|
||||
|
||||
supports_catalog = False
|
||||
|
||||
custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = {
|
||||
COLUMN_DOES_NOT_EXIST_REGEX: (
|
||||
__(
|
||||
|
|
@ -815,11 +817,11 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
|
|||
cls,
|
||||
database: Database,
|
||||
inspector: Inspector,
|
||||
) -> list[str]:
|
||||
) -> set[str]:
|
||||
"""
|
||||
Get all catalogs.
|
||||
"""
|
||||
return [catalog for (catalog,) in inspector.bind.execute("SHOW CATALOGS")]
|
||||
return {catalog for (catalog,) in inspector.bind.execute("SHOW CATALOGS")}
|
||||
|
||||
@classmethod
|
||||
def _create_column_info(
|
||||
|
|
|
|||
|
|
@ -85,7 +85,7 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
|
|||
sqlalchemy_uri_placeholder = "snowflake://"
|
||||
|
||||
supports_dynamic_schema = True
|
||||
supports_catalog = True
|
||||
supports_catalog = False
|
||||
|
||||
_time_grain_expressions = {
|
||||
None: "{col}",
|
||||
|
|
@ -174,18 +174,18 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
|
|||
cls,
|
||||
database: "Database",
|
||||
inspector: Inspector,
|
||||
) -> list[str]:
|
||||
) -> set[str]:
|
||||
"""
|
||||
Return all catalogs.
|
||||
|
||||
In Snowflake, a catalog is called a "database".
|
||||
"""
|
||||
return sorted(
|
||||
return {
|
||||
catalog
|
||||
for (catalog,) in inspector.bind.execute(
|
||||
"SELECT DATABASE_NAME from information_schema.databases"
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def epoch_to_dttm(cls) -> str:
|
||||
|
|
|
|||
|
|
@ -270,11 +270,6 @@ class SupersetShillelaghAdapter(Adapter):
|
|||
self.schema = parts.pop(-1) if parts else None
|
||||
self.catalog = parts.pop(-1) if parts else None
|
||||
|
||||
if self.catalog:
|
||||
# TODO (betodealmeida): when SIP-95 is implemented we should check to see if
|
||||
# the database has multi-catalog enabled, and if so, give access.
|
||||
raise NotImplementedError("Catalogs are not currently supported")
|
||||
|
||||
# If the table has a single integer primary key we use that as the row ID in order
|
||||
# to perform updates and deletes. Otherwise we can only do inserts and selects.
|
||||
self._rowid: str | None = None
|
||||
|
|
|
|||
|
|
@ -0,0 +1,116 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from alembic import op
|
||||
|
||||
from superset import db, security_manager
|
||||
from superset.daos.database import DatabaseDAO
|
||||
from superset.models.core import Database
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def upgrade_schema_perms(engine: str | None = None) -> None:
|
||||
"""
|
||||
Update schema permissions to include the catalog part.
|
||||
|
||||
Before SIP-95 schema permissions were stored in the format `[db].[schema]`. With the
|
||||
introduction of catalogs, any existing permissions need to be renamed to include the
|
||||
catalog: `[db].[catalog].[schema]`.
|
||||
"""
|
||||
bind = op.get_bind()
|
||||
session = db.Session(bind=bind)
|
||||
for database in session.query(Database).all():
|
||||
db_engine_spec = database.db_engine_spec
|
||||
if (
|
||||
engine and db_engine_spec.engine != engine
|
||||
) or not db_engine_spec.supports_catalog:
|
||||
continue
|
||||
|
||||
catalog = database.get_default_catalog()
|
||||
ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
|
||||
for schema in database.get_all_schema_names(
|
||||
catalog=catalog,
|
||||
cache=False,
|
||||
ssh_tunnel=ssh_tunnel,
|
||||
):
|
||||
perm = security_manager.get_schema_perm(
|
||||
database.database_name,
|
||||
None,
|
||||
schema,
|
||||
)
|
||||
existing_pvm = security_manager.find_permission_view_menu(
|
||||
"schema_access",
|
||||
perm,
|
||||
)
|
||||
if existing_pvm:
|
||||
existing_pvm.view_menu.name = security_manager.get_schema_perm(
|
||||
database.database_name,
|
||||
catalog,
|
||||
schema,
|
||||
)
|
||||
|
||||
session.commit()
|
||||
|
||||
|
||||
def downgrade_schema_perms(engine: str | None = None) -> None:
|
||||
"""
|
||||
Update schema permissions to not have the catalog part.
|
||||
|
||||
Before SIP-95 schema permissions were stored in the format `[db].[schema]`. With the
|
||||
introduction of catalogs, any existing permissions need to be renamed to include the
|
||||
catalog: `[db].[catalog].[schema]`.
|
||||
|
||||
This helped function reverts the process.
|
||||
"""
|
||||
bind = op.get_bind()
|
||||
session = db.Session(bind=bind)
|
||||
for database in session.query(Database).all():
|
||||
db_engine_spec = database.db_engine_spec
|
||||
if (
|
||||
engine and db_engine_spec.engine != engine
|
||||
) or not db_engine_spec.supports_catalog:
|
||||
continue
|
||||
|
||||
catalog = database.get_default_catalog()
|
||||
ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
|
||||
for schema in database.get_all_schema_names(
|
||||
catalog=catalog,
|
||||
cache=False,
|
||||
ssh_tunnel=ssh_tunnel,
|
||||
):
|
||||
perm = security_manager.get_schema_perm(
|
||||
database.database_name,
|
||||
catalog,
|
||||
schema,
|
||||
)
|
||||
existing_pvm = security_manager.find_permission_view_menu(
|
||||
"schema_access",
|
||||
perm,
|
||||
)
|
||||
if existing_pvm:
|
||||
existing_pvm.view_menu.name = security_manager.get_schema_perm(
|
||||
database.database_name,
|
||||
None,
|
||||
schema,
|
||||
)
|
||||
|
||||
session.commit()
|
||||
|
|
@ -0,0 +1,53 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""Add catalog_perm to tables
|
||||
|
||||
Revision ID: 58d051681a3b
|
||||
Revises: 4a33124c18ad
|
||||
Create Date: 2024-05-01 10:52:31.458433
|
||||
|
||||
"""
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
from superset.migrations.shared.catalogs import (
|
||||
downgrade_schema_perms,
|
||||
upgrade_schema_perms,
|
||||
)
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "58d051681a3b"
|
||||
down_revision = "4a33124c18ad"
|
||||
|
||||
|
||||
def upgrade():
|
||||
op.add_column(
|
||||
"tables",
|
||||
sa.Column("catalog_perm", sa.String(length=1000), nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
"slices",
|
||||
sa.Column("catalog_perm", sa.String(length=1000), nullable=True),
|
||||
)
|
||||
upgrade_schema_perms(engine="postgresql")
|
||||
|
||||
|
||||
def downgrade():
|
||||
op.drop_column("slices", "catalog_perm")
|
||||
op.drop_column("tables", "catalog_perm")
|
||||
downgrade_schema_perms(engine="postgresql")
|
||||
|
|
@ -78,7 +78,7 @@ from superset.sql_parse import Table
|
|||
from superset.superset_typing import OAuth2ClientConfig, ResultSetColumnType
|
||||
from superset.utils import cache as cache_util, core as utils
|
||||
from superset.utils.backports import StrEnum
|
||||
from superset.utils.core import get_username
|
||||
from superset.utils.core import DatasourceName, get_username
|
||||
from superset.utils.oauth2 import get_oauth2_access_token
|
||||
|
||||
config = app.config
|
||||
|
|
@ -313,6 +313,14 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
|
|||
def metadata_cache_timeout(self) -> dict[str, Any]:
|
||||
return self.get_extra().get("metadata_cache_timeout", {})
|
||||
|
||||
@property
|
||||
def catalog_cache_enabled(self) -> bool:
|
||||
return "catalog_cache_timeout" in self.metadata_cache_timeout
|
||||
|
||||
@property
|
||||
def catalog_cache_timeout(self) -> int | None:
|
||||
return self.metadata_cache_timeout.get("catalog_cache_timeout")
|
||||
|
||||
@property
|
||||
def schema_cache_enabled(self) -> bool:
|
||||
return "schema_cache_timeout" in self.metadata_cache_timeout
|
||||
|
|
@ -549,6 +557,18 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
|
|||
|
||||
yield conn
|
||||
|
||||
def get_default_catalog(self) -> str | None:
|
||||
"""
|
||||
Return the default configured catalog for the database.
|
||||
"""
|
||||
return self.db_engine_spec.get_default_catalog(self)
|
||||
|
||||
def get_default_schema(self, catalog: str | None) -> str | None:
|
||||
"""
|
||||
Return the default schema for the database.
|
||||
"""
|
||||
return self.db_engine_spec.get_default_schema(self, catalog)
|
||||
|
||||
def get_default_schema_for_query(self, query: Query) -> str | None:
|
||||
"""
|
||||
Return the default schema for a given query.
|
||||
|
|
@ -706,19 +726,17 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
|
|||
key="db:{self.id}:schema:{schema}:table_list",
|
||||
cache=cache_manager.cache,
|
||||
)
|
||||
def get_all_table_names_in_schema( # pylint: disable=unused-argument
|
||||
def get_all_table_names_in_schema(
|
||||
self,
|
||||
catalog: str | None,
|
||||
schema: str,
|
||||
cache: bool = False,
|
||||
cache_timeout: int | None = None,
|
||||
force: bool = False,
|
||||
) -> set[tuple[str, str]]:
|
||||
) -> set[DatasourceName]:
|
||||
"""Parameters need to be passed as keyword arguments.
|
||||
|
||||
For unused parameters, they are referenced in
|
||||
cache_util.memoized_func decorator.
|
||||
|
||||
:param catalog: optional catalog name
|
||||
:param schema: schema name
|
||||
:param cache: whether cache is enabled for the function
|
||||
:param cache_timeout: timeout in seconds for the cache
|
||||
|
|
@ -728,7 +746,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
|
|||
try:
|
||||
with self.get_inspector(catalog=catalog, schema=schema) as inspector:
|
||||
return {
|
||||
(table, schema)
|
||||
DatasourceName(table, schema, catalog)
|
||||
for table in self.db_engine_spec.get_table_names(
|
||||
database=self,
|
||||
inspector=inspector,
|
||||
|
|
@ -742,19 +760,17 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
|
|||
key="db:{self.id}:schema:{schema}:view_list",
|
||||
cache=cache_manager.cache,
|
||||
)
|
||||
def get_all_view_names_in_schema( # pylint: disable=unused-argument
|
||||
def get_all_view_names_in_schema(
|
||||
self,
|
||||
catalog: str | None,
|
||||
schema: str,
|
||||
cache: bool = False,
|
||||
cache_timeout: int | None = None,
|
||||
force: bool = False,
|
||||
) -> set[tuple[str, str]]:
|
||||
) -> set[DatasourceName]:
|
||||
"""Parameters need to be passed as keyword arguments.
|
||||
|
||||
For unused parameters, they are referenced in
|
||||
cache_util.memoized_func decorator.
|
||||
|
||||
:param catalog: optional catalog name
|
||||
:param schema: schema name
|
||||
:param cache: whether cache is enabled for the function
|
||||
:param cache_timeout: timeout in seconds for the cache
|
||||
|
|
@ -764,7 +780,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
|
|||
try:
|
||||
with self.get_inspector(catalog=catalog, schema=schema) as inspector:
|
||||
return {
|
||||
(view, schema)
|
||||
DatasourceName(view, schema, catalog)
|
||||
for view in self.db_engine_spec.get_view_names(
|
||||
database=self,
|
||||
inspector=inspector,
|
||||
|
|
@ -792,22 +808,17 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
|
|||
key="db:{self.id}:schema_list",
|
||||
cache=cache_manager.cache,
|
||||
)
|
||||
def get_all_schema_names( # pylint: disable=unused-argument
|
||||
def get_all_schema_names(
|
||||
self,
|
||||
*,
|
||||
catalog: str | None = None,
|
||||
cache: bool = False,
|
||||
cache_timeout: int | None = None,
|
||||
force: bool = False,
|
||||
ssh_tunnel: SSHTunnel | None = None,
|
||||
) -> list[str]:
|
||||
"""Parameters need to be passed as keyword arguments.
|
||||
) -> set[str]:
|
||||
"""
|
||||
Return the schemas in a given database
|
||||
|
||||
For unused parameters, they are referenced in
|
||||
cache_util.memoized_func decorator.
|
||||
|
||||
:param cache: whether cache is enabled for the function
|
||||
:param cache_timeout: timeout in seconds for the cache
|
||||
:param force: whether to force refresh the cache
|
||||
:param catalog: override default catalog
|
||||
:param ssh_tunnel: SSH tunnel information needed to establish a connection
|
||||
:return: schema list
|
||||
"""
|
||||
try:
|
||||
|
|
@ -819,6 +830,27 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
|
|||
except Exception as ex:
|
||||
raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex
|
||||
|
||||
@cache_util.memoized_func(
|
||||
key="db:{self.id}:catalog_list",
|
||||
cache=cache_manager.cache,
|
||||
)
|
||||
def get_all_catalog_names(
|
||||
self,
|
||||
*,
|
||||
ssh_tunnel: SSHTunnel | None = None,
|
||||
) -> set[str]:
|
||||
"""
|
||||
Return the catalogs in a given database
|
||||
|
||||
:param ssh_tunnel: SSH tunnel information needed to establish a connection
|
||||
:return: catalog list
|
||||
"""
|
||||
try:
|
||||
with self.get_inspector(ssh_tunnel=ssh_tunnel) as inspector:
|
||||
return self.db_engine_spec.get_catalog_names(self, inspector)
|
||||
except Exception as ex:
|
||||
raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex
|
||||
|
||||
@property
|
||||
def db_engine_spec(self) -> builtins.type[db_engine_specs.BaseEngineSpec]:
|
||||
url = make_url_safe(self.sqlalchemy_uri_decrypted)
|
||||
|
|
|
|||
|
|
@ -209,7 +209,10 @@ class ImportExportMixin:
|
|||
"""Get a mapping of foreign name to the local name of foreign keys"""
|
||||
parent_rel = cls.__mapper__.relationships.get(cls.export_parent)
|
||||
if parent_rel:
|
||||
return {l.name: r.name for (l, r) in parent_rel.local_remote_pairs} # noqa: E741
|
||||
return {
|
||||
local.name: remote.name
|
||||
for (local, remote) in parent_rel.local_remote_pairs
|
||||
}
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
|
|
@ -772,6 +775,10 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
|
|||
def database(self) -> "Database":
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
def catalog(self) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
def schema(self) -> str:
|
||||
raise NotImplementedError()
|
||||
|
|
@ -1025,7 +1032,12 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
|
|||
return df
|
||||
|
||||
try:
|
||||
df = self.database.get_df(sql, self.schema, mutator=assign_column_label)
|
||||
df = self.database.get_df(
|
||||
sql,
|
||||
self.catalog,
|
||||
self.schema,
|
||||
mutator=assign_column_label,
|
||||
)
|
||||
except Exception as ex: # pylint: disable=broad-except
|
||||
df = pd.DataFrame()
|
||||
status = QueryStatus.FAILED
|
||||
|
|
|
|||
|
|
@ -84,6 +84,7 @@ class Slice( # pylint: disable=too-many-public-methods
|
|||
cache_timeout = Column(Integer)
|
||||
perm = Column(String(1000))
|
||||
schema_perm = Column(String(1000))
|
||||
catalog_perm = Column(String(1000), nullable=True, default=None)
|
||||
# the last time a user has saved the chart, changed_on is referencing
|
||||
# when the database row was last written
|
||||
last_saved_at = Column(DateTime, nullable=True)
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ import logging
|
|||
import re
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import Any, Callable, cast, NamedTuple, Optional, TYPE_CHECKING, Union
|
||||
from typing import Any, Callable, cast, NamedTuple, Optional, TYPE_CHECKING
|
||||
|
||||
from flask import current_app, Flask, g, Request
|
||||
from flask_appbuilder import Model
|
||||
|
|
@ -67,7 +67,7 @@ from superset.security.guest_token import (
|
|||
GuestTokenUser,
|
||||
GuestUser,
|
||||
)
|
||||
from superset.sql_parse import extract_tables_from_jinja_sql
|
||||
from superset.sql_parse import extract_tables_from_jinja_sql, Table
|
||||
from superset.superset_typing import Metric
|
||||
from superset.utils.core import (
|
||||
DatasourceName,
|
||||
|
|
@ -89,7 +89,6 @@ if TYPE_CHECKING:
|
|||
from superset.models.dashboard import Dashboard
|
||||
from superset.models.slice import Slice
|
||||
from superset.models.sql_lab import Query
|
||||
from superset.sql_parse import Table
|
||||
from superset.viz import BaseViz
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -97,8 +96,9 @@ logger = logging.getLogger(__name__)
|
|||
DATABASE_PERM_REGEX = re.compile(r"^\[.+\]\.\(id\:(?P<id>\d+)\)$")
|
||||
|
||||
|
||||
class DatabaseAndSchema(NamedTuple):
|
||||
class DatabaseCatalogSchema(NamedTuple):
|
||||
database: str
|
||||
catalog: Optional[str]
|
||||
schema: str
|
||||
|
||||
|
||||
|
|
@ -346,17 +346,51 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
|||
return self.get_guest_user_from_request(request)
|
||||
return None
|
||||
|
||||
def get_catalog_perm(
|
||||
self,
|
||||
database: str,
|
||||
catalog: Optional[str] = None,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Return the database specific catalog permission.
|
||||
|
||||
:param database: The Superset database or database name
|
||||
:param catalog: The database catalog name
|
||||
:return: The database specific schema permission
|
||||
"""
|
||||
if catalog is None:
|
||||
return None
|
||||
|
||||
return f"[{database}].[{catalog}]"
|
||||
|
||||
def get_schema_perm(
|
||||
self, database: Union["Database", str], schema: Optional[str] = None
|
||||
self,
|
||||
database: str,
|
||||
catalog: Optional[str] = None,
|
||||
schema: Optional[str] = None,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Return the database specific schema permission.
|
||||
|
||||
:param database: The Superset database or database name
|
||||
:param schema: The Superset schema name
|
||||
Catalogs were added in SIP-95, and not all databases support them. Because of
|
||||
this, the format used for permissions is different depending on whether a
|
||||
catalog is passed or not:
|
||||
|
||||
[database].[schema]
|
||||
[database].[catalog].[schema]
|
||||
|
||||
:param database: The database name
|
||||
:param catalog: The database catalog name
|
||||
:param schema: The database schema name
|
||||
:return: The database specific schema permission
|
||||
"""
|
||||
return f"[{database}].[{schema}]" if schema else None
|
||||
if schema is None:
|
||||
return None
|
||||
|
||||
if catalog:
|
||||
return f"[{database}].[{catalog}].[{schema}]"
|
||||
|
||||
return f"[{database}].[{schema}]"
|
||||
|
||||
@staticmethod
|
||||
def get_database_perm(database_id: int, database_name: str) -> Optional[str]:
|
||||
|
|
@ -370,13 +404,6 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
|||
) -> Optional[str]:
|
||||
return f"[{database_name}].[{dataset_name}](id:{dataset_id})"
|
||||
|
||||
def unpack_database_and_schema(self, schema_permission: str) -> DatabaseAndSchema:
|
||||
# [database_name].[schema|table]
|
||||
|
||||
schema_name = schema_permission.split(".")[1][1:-1]
|
||||
database_name = schema_permission.split(".")[0][1:-1]
|
||||
return DatabaseAndSchema(database_name, schema_name)
|
||||
|
||||
def can_access(self, permission_name: str, view_name: str) -> bool:
|
||||
"""
|
||||
Return True if the user can access the FAB permission/view, False otherwise.
|
||||
|
|
@ -436,6 +463,17 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
|||
or self.can_access("database_access", database.perm) # type: ignore
|
||||
)
|
||||
|
||||
def can_access_catalog(self, database: "Database", catalog: str) -> bool:
|
||||
"""
|
||||
Return if the user can access the specified catalog.
|
||||
"""
|
||||
catalog_perm = self.get_catalog_perm(database.database_name, catalog)
|
||||
return bool(
|
||||
self.can_access_all_datasources()
|
||||
or self.can_access_database(database)
|
||||
or (catalog_perm and self.can_access("catalog_access", catalog_perm))
|
||||
)
|
||||
|
||||
def can_access_schema(self, datasource: "BaseDatasource") -> bool:
|
||||
"""
|
||||
Return True if the user can access the schema associated with specified
|
||||
|
|
@ -448,6 +486,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
|||
return (
|
||||
self.can_access_all_datasources()
|
||||
or self.can_access_database(datasource.database)
|
||||
or self.can_access_catalog(datasource.database, datasource.catalog)
|
||||
or self.can_access("schema_access", datasource.schema_perm or "")
|
||||
)
|
||||
|
||||
|
|
@ -706,55 +745,150 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
|||
]
|
||||
|
||||
def get_schemas_accessible_by_user(
|
||||
self, database: "Database", schemas: list[str], hierarchical: bool = True
|
||||
) -> list[str]:
|
||||
self,
|
||||
database: "Database",
|
||||
catalog: Optional[str],
|
||||
schemas: set[str],
|
||||
hierarchical: bool = True,
|
||||
) -> set[str]:
|
||||
"""
|
||||
Return the list of SQL schemas accessible by the user.
|
||||
Returned a filtered list of the schemas accessible by the user.
|
||||
|
||||
If not catalog is specified, the default catalog is used.
|
||||
|
||||
:param database: The SQL database
|
||||
:param schemas: The list of eligible SQL schemas
|
||||
:param catalog: An optional database catalog
|
||||
:param schemas: A set of candidate schemas
|
||||
:param hierarchical: Whether to check using the hierarchical permission logic
|
||||
:returns: The list of accessible SQL schemas
|
||||
:returns: The set of accessible database schemas
|
||||
"""
|
||||
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
|
||||
if hierarchical and self.can_access_database(database):
|
||||
if hierarchical and (
|
||||
self.can_access_database(database)
|
||||
or (catalog and self.can_access_catalog(database, catalog))
|
||||
):
|
||||
return schemas
|
||||
|
||||
# schema_access
|
||||
accessible_schemas = {
|
||||
self.unpack_database_and_schema(s).schema
|
||||
for s in self.user_view_menu_names("schema_access")
|
||||
if s.startswith(f"[{database}].")
|
||||
}
|
||||
accessible_schemas: set[str] = set()
|
||||
schema_access = self.user_view_menu_names("schema_access")
|
||||
default_catalog = database.get_default_catalog()
|
||||
default_schema = database.get_default_schema(default_catalog)
|
||||
|
||||
for perm in schema_access:
|
||||
parts = [part[1:-1] for part in perm.split(".")]
|
||||
|
||||
if parts[0] != database.database_name:
|
||||
continue
|
||||
|
||||
# [database].[schema] matches when no catalog is specified, or when the user
|
||||
# specifies the default catalog
|
||||
if len(parts) == 2 and (catalog is None or catalog == default_catalog):
|
||||
accessible_schemas.add(parts[1])
|
||||
|
||||
# [database].[catalog].[schema] matches when the catalog is equal to the
|
||||
# requested catalog or, when no catalog specified, it's equal to the default
|
||||
# catalog.
|
||||
elif len(parts) == 3 and parts[1] == (catalog or default_catalog):
|
||||
accessible_schemas.add(parts[2])
|
||||
|
||||
# datasource_access
|
||||
if perms := self.user_view_menu_names("datasource_access"):
|
||||
tables = (
|
||||
self.get_session.query(SqlaTable.schema)
|
||||
.filter(SqlaTable.database_id == database.id)
|
||||
.filter(SqlaTable.schema.isnot(None))
|
||||
.filter(SqlaTable.schema != "")
|
||||
.filter(or_(SqlaTable.perm.in_(perms)))
|
||||
.distinct()
|
||||
)
|
||||
accessible_schemas.update([table.schema for table in tables])
|
||||
accessible_schemas.update(
|
||||
{
|
||||
table.schema or default_schema # type: ignore
|
||||
for table in tables
|
||||
if (table.schema or default_schema)
|
||||
}
|
||||
)
|
||||
|
||||
return [s for s in schemas if s in accessible_schemas]
|
||||
return schemas & accessible_schemas
|
||||
|
||||
def get_catalogs_accessible_by_user(
|
||||
self,
|
||||
database: "Database",
|
||||
catalogs: set[str],
|
||||
hierarchical: bool = True,
|
||||
) -> set[str]:
|
||||
"""
|
||||
Returned a filtered list of the catalogs accessible by the user.
|
||||
|
||||
:param database: The SQL database
|
||||
:param catalogs: A set of candidate catalogs
|
||||
:param hierarchical: Whether to check using the hierarchical permission logic
|
||||
:returns: The set of accessible database catalogs
|
||||
"""
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
|
||||
if hierarchical and self.can_access_database(database):
|
||||
return catalogs
|
||||
|
||||
# catalog access
|
||||
accessible_catalogs: set[str] = set()
|
||||
catalog_access = self.user_view_menu_names("catalog_access")
|
||||
default_catalog = database.get_default_catalog()
|
||||
|
||||
for perm in catalog_access:
|
||||
parts = [part[1:-1] for part in perm.split(".")]
|
||||
if parts[0] == database.database_name:
|
||||
accessible_catalogs.add(parts[1])
|
||||
|
||||
# schema access
|
||||
schema_access = self.user_view_menu_names("schema_access")
|
||||
for perm in schema_access:
|
||||
parts = [part[1:-1] for part in perm.split(".")]
|
||||
|
||||
if parts[0] != database.database_name:
|
||||
continue
|
||||
if len(parts) == 2 and default_catalog:
|
||||
accessible_catalogs.add(default_catalog)
|
||||
elif len(parts) == 3:
|
||||
accessible_catalogs.add(parts[2])
|
||||
|
||||
# datasource_access
|
||||
if perms := self.user_view_menu_names("datasource_access"):
|
||||
tables = (
|
||||
self.get_session.query(SqlaTable.schema)
|
||||
.filter(SqlaTable.database_id == database.id)
|
||||
.filter(or_(SqlaTable.perm.in_(perms)))
|
||||
.distinct()
|
||||
)
|
||||
accessible_catalogs.update(
|
||||
{
|
||||
table.catalog or default_catalog # type: ignore
|
||||
for table in tables
|
||||
if (table.catalog or default_catalog)
|
||||
}
|
||||
)
|
||||
|
||||
return catalogs & accessible_catalogs
|
||||
|
||||
def get_datasources_accessible_by_user( # pylint: disable=invalid-name
|
||||
self,
|
||||
database: "Database",
|
||||
datasource_names: list[DatasourceName],
|
||||
catalog: Optional[str] = None,
|
||||
schema: Optional[str] = None,
|
||||
) -> list[DatasourceName]:
|
||||
"""
|
||||
Return the list of SQL tables accessible by the user.
|
||||
Filter list of SQL tables to the ones accessible by the user.
|
||||
|
||||
When catalog and/or schema are specified, it's assumed that all datasources in
|
||||
`datasource_names` are in the given catalog/schema.
|
||||
|
||||
:param database: The SQL database
|
||||
:param datasource_names: The list of eligible SQL tables w/ schema
|
||||
:param catalog: The fallback SQL catalog if not present in the table name
|
||||
:param schema: The fallback SQL schema if not present in the table name
|
||||
:returns: The list of accessible SQL tables w/ schema
|
||||
"""
|
||||
|
|
@ -764,22 +898,34 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
|||
if self.can_access_database(database):
|
||||
return datasource_names
|
||||
|
||||
if catalog:
|
||||
catalog_perm = self.get_catalog_perm(database.database_name, catalog)
|
||||
if catalog_perm and self.can_access("catalog_access", catalog_perm):
|
||||
return datasource_names
|
||||
|
||||
if schema:
|
||||
schema_perm = self.get_schema_perm(database, schema)
|
||||
schema_perm = self.get_schema_perm(database, catalog, schema)
|
||||
if schema_perm and self.can_access("schema_access", schema_perm):
|
||||
return datasource_names
|
||||
|
||||
user_perms = self.user_view_menu_names("datasource_access")
|
||||
catalog_perms = self.user_view_menu_names("catalog_access")
|
||||
schema_perms = self.user_view_menu_names("schema_access")
|
||||
user_datasources = SqlaTable.query_datasources_by_permissions(
|
||||
database, user_perms, schema_perms
|
||||
)
|
||||
if schema:
|
||||
names = {d.table_name for d in user_datasources if d.schema == schema}
|
||||
return [d for d in datasource_names if d.table in names]
|
||||
user_datasources = {
|
||||
DatasourceName(table.table_name, table.schema, table.catalog)
|
||||
for table in SqlaTable.query_datasources_by_permissions(
|
||||
database,
|
||||
user_perms,
|
||||
catalog_perms,
|
||||
schema_perms,
|
||||
)
|
||||
}
|
||||
|
||||
full_names = {d.full_name for d in user_datasources}
|
||||
return [d for d in datasource_names if f"[{database}].[{d}]" in full_names]
|
||||
return [
|
||||
datasource
|
||||
for datasource in datasource_names
|
||||
if datasource in user_datasources
|
||||
]
|
||||
|
||||
def merge_perm(self, permission_name: str, view_menu_name: str) -> None:
|
||||
"""
|
||||
|
|
@ -844,6 +990,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
|||
for datasource in datasources:
|
||||
merge_pv("datasource_access", datasource.get_perm())
|
||||
merge_pv("schema_access", datasource.get_schema_perm())
|
||||
merge_pv("catalog_access", datasource.get_catalog_perm())
|
||||
|
||||
logger.info("Creating missing database permissions.")
|
||||
databases = self.get_session.query(models.Database).all()
|
||||
|
|
@ -1212,7 +1359,12 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
|||
self.get_session.query(self.permissionview_model)
|
||||
.join(self.permission_model)
|
||||
.join(self.viewmenu_model)
|
||||
.filter(self.permission_model.name == "schema_access")
|
||||
.filter(
|
||||
or_(
|
||||
self.permission_model.name == "schema_access",
|
||||
self.permission_model.name == "catalog_access",
|
||||
)
|
||||
)
|
||||
.filter(self.viewmenu_model.name.like(f"[{database_name}].[%]"))
|
||||
.all()
|
||||
)
|
||||
|
|
@ -1399,18 +1551,43 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
|||
.values(perm=dataset_perm)
|
||||
)
|
||||
|
||||
# update catalog and schema perms
|
||||
values: dict[str, Optional[str]] = {}
|
||||
|
||||
if target.schema:
|
||||
dataset_schema_perm = self.get_schema_perm(
|
||||
database.database_name, target.schema
|
||||
database.database_name,
|
||||
target.catalog,
|
||||
target.schema,
|
||||
)
|
||||
self._insert_pvm_on_sqla_event(
|
||||
mapper, connection, "schema_access", dataset_schema_perm
|
||||
mapper,
|
||||
connection,
|
||||
"schema_access",
|
||||
dataset_schema_perm,
|
||||
)
|
||||
target.schema_perm = dataset_schema_perm
|
||||
values["schema_perm"] = dataset_schema_perm
|
||||
|
||||
if target.catalog:
|
||||
dataset_catalog_perm = self.get_catalog_perm(
|
||||
database.database_name,
|
||||
target.catalog,
|
||||
)
|
||||
self._insert_pvm_on_sqla_event(
|
||||
mapper,
|
||||
connection,
|
||||
"catalog_access",
|
||||
dataset_catalog_perm,
|
||||
)
|
||||
target.catalog_perm = dataset_catalog_perm
|
||||
values["catalog_perm"] = dataset_catalog_perm
|
||||
|
||||
if values:
|
||||
connection.execute(
|
||||
dataset_table.update()
|
||||
.where(dataset_table.c.id == target.id)
|
||||
.values(schema_perm=dataset_schema_perm)
|
||||
.values(**values)
|
||||
)
|
||||
|
||||
def dataset_after_delete(
|
||||
|
|
@ -1467,6 +1644,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
|||
table.select().where(table.c.id == target.id)
|
||||
).one()
|
||||
current_db_id = current_dataset.database_id
|
||||
current_catalog = current_dataset.catalog
|
||||
current_schema = current_dataset.schema
|
||||
current_table_name = current_dataset.table_name
|
||||
|
||||
|
|
@ -1479,14 +1657,21 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
|||
mapper, connection, target.perm, new_dataset_vm_name, target
|
||||
)
|
||||
|
||||
# Updates schema permissions
|
||||
new_dataset_schema_name = self.get_schema_perm(
|
||||
target.database.database_name, target.schema
|
||||
# Updates catalog/schema permissions
|
||||
dataset_catalog_name = self.get_catalog_perm(
|
||||
target.database.database_name,
|
||||
target.catalog,
|
||||
)
|
||||
self._update_dataset_schema_perm(
|
||||
dataset_schema_name = self.get_schema_perm(
|
||||
target.database.database_name,
|
||||
target.catalog,
|
||||
target.schema,
|
||||
)
|
||||
self._update_dataset_catalog_schema_perm(
|
||||
mapper,
|
||||
connection,
|
||||
new_dataset_schema_name,
|
||||
dataset_catalog_name,
|
||||
dataset_schema_name,
|
||||
target,
|
||||
)
|
||||
|
||||
|
|
@ -1502,23 +1687,32 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
|||
mapper, connection, old_dataset_vm_name, new_dataset_vm_name, target
|
||||
)
|
||||
|
||||
# When schema changes
|
||||
if current_schema != target.schema:
|
||||
new_dataset_schema_name = self.get_schema_perm(
|
||||
target.database.database_name, target.schema
|
||||
# When catalog/schema change
|
||||
if current_catalog != target.catalog or current_schema != target.schema:
|
||||
dataset_catalog_name = self.get_catalog_perm(
|
||||
target.database.database_name,
|
||||
target.catalog,
|
||||
)
|
||||
self._update_dataset_schema_perm(
|
||||
dataset_schema_name = self.get_schema_perm(
|
||||
target.database.database_name,
|
||||
target.catalog,
|
||||
target.schema,
|
||||
)
|
||||
self._update_dataset_catalog_schema_perm(
|
||||
mapper,
|
||||
connection,
|
||||
new_dataset_schema_name,
|
||||
dataset_catalog_name,
|
||||
dataset_schema_name,
|
||||
target,
|
||||
)
|
||||
|
||||
def _update_dataset_schema_perm(
|
||||
# pylint: disable=invalid-name, too-many-arguments
|
||||
def _update_dataset_catalog_schema_perm(
|
||||
self,
|
||||
mapper: Mapper,
|
||||
connection: Connection,
|
||||
new_schema_permission_name: Optional[str],
|
||||
catalog_permission_name: Optional[str],
|
||||
schema_permission_name: Optional[str],
|
||||
target: "SqlaTable",
|
||||
) -> None:
|
||||
"""
|
||||
|
|
@ -1530,11 +1724,11 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
|||
|
||||
:param mapper: The SQLA event mapper
|
||||
:param connection: The SQLA connection
|
||||
:param new_schema_permission_name: The new schema permission name that changed
|
||||
:param catalog_permission_name: The new catalog permission name that changed
|
||||
:param schema_permission_name: The new schema permission name that changed
|
||||
:param target: Dataset that was updated
|
||||
:return:
|
||||
"""
|
||||
logger.info("Updating schema perm, new: %s", new_schema_permission_name)
|
||||
from superset.connectors.sqla.models import ( # pylint: disable=import-outside-toplevel
|
||||
SqlaTable,
|
||||
)
|
||||
|
|
@ -1545,18 +1739,30 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
|||
sqlatable_table = SqlaTable.__table__ # pylint: disable=no-member
|
||||
chart_table = Slice.__table__ # pylint: disable=no-member
|
||||
|
||||
# insert new schema PVM if it does not exist
|
||||
# insert new PVMs if they don't not exist
|
||||
self._insert_pvm_on_sqla_event(
|
||||
mapper, connection, "schema_access", new_schema_permission_name
|
||||
mapper,
|
||||
connection,
|
||||
"catalog_access",
|
||||
catalog_permission_name,
|
||||
)
|
||||
self._insert_pvm_on_sqla_event(
|
||||
mapper,
|
||||
connection,
|
||||
"schema_access",
|
||||
schema_permission_name,
|
||||
)
|
||||
|
||||
# Update dataset (SqlaTable schema_perm field)
|
||||
# Update dataset
|
||||
connection.execute(
|
||||
sqlatable_table.update()
|
||||
.where(
|
||||
sqlatable_table.c.id == target.id,
|
||||
)
|
||||
.values(schema_perm=new_schema_permission_name)
|
||||
.values(
|
||||
catalog_perm=catalog_permission_name,
|
||||
schema_perm=schema_permission_name,
|
||||
)
|
||||
)
|
||||
|
||||
# Update charts (Slice schema_perm field)
|
||||
|
|
@ -1566,7 +1772,10 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
|||
chart_table.c.datasource_id == target.id,
|
||||
chart_table.c.datasource_type == DatasourceType.TABLE,
|
||||
)
|
||||
.values(schema_perm=new_schema_permission_name)
|
||||
.values(
|
||||
catalog_perm=catalog_permission_name,
|
||||
schema_perm=schema_permission_name,
|
||||
)
|
||||
)
|
||||
|
||||
def _update_dataset_perm( # pylint: disable=too-many-arguments
|
||||
|
|
@ -1923,7 +2132,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
|||
table: Optional["Table"] = None,
|
||||
viz: Optional["BaseViz"] = None,
|
||||
sql: Optional[str] = None,
|
||||
catalog: Optional[str] = None, # pylint: disable=unused-argument
|
||||
catalog: Optional[str] = None,
|
||||
schema: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
|
|
@ -1947,7 +2156,6 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
|||
from superset.models.dashboard import Dashboard
|
||||
from superset.models.slice import Slice
|
||||
from superset.models.sql_lab import Query
|
||||
from superset.sql_parse import Table
|
||||
from superset.utils.core import shortid
|
||||
|
||||
if sql and database:
|
||||
|
|
@ -1955,6 +2163,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
|||
database=database,
|
||||
sql=sql,
|
||||
schema=schema,
|
||||
catalog=catalog,
|
||||
client_id=shortid()[:10],
|
||||
user_id=get_user_id(),
|
||||
)
|
||||
|
|
@ -1970,9 +2179,23 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
|||
return
|
||||
|
||||
if query:
|
||||
# Getting the default schema for a query is hard. Users can select the
|
||||
# schema in SQL Lab, but there's no guarantee that the query actually
|
||||
# will run in that schema. Each DB engine spec needs to implement the
|
||||
# necessary logic to enforce that the query runs in the selected schema.
|
||||
# If the DB engine spec doesn't implement the logic the schema is read
|
||||
# from the SQLAlchemy URI if possible; if not, we use the SQLAlchemy
|
||||
# inspector to read it.
|
||||
default_schema = database.get_default_schema_for_query(query)
|
||||
# Determining the default catalog is much easier, because DB engine
|
||||
# specs need explicit support for catalogs.
|
||||
default_catalog = database.get_default_catalog()
|
||||
tables = {
|
||||
Table(table_.table, table_.schema or default_schema)
|
||||
Table(
|
||||
table_.table,
|
||||
table_.schema or default_schema,
|
||||
table_.catalog or default_catalog,
|
||||
)
|
||||
for table_ in extract_tables_from_jinja_sql(query.sql, database)
|
||||
}
|
||||
elif table:
|
||||
|
|
@ -1981,21 +2204,36 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
|||
denied = set()
|
||||
|
||||
for table_ in tables:
|
||||
schema_perm = self.get_schema_perm(database, schema=table_.schema)
|
||||
catalog_perm = self.get_catalog_perm(
|
||||
database.database_name,
|
||||
table_.catalog,
|
||||
)
|
||||
if catalog_perm and self.can_access("catalog_access", catalog_perm):
|
||||
continue
|
||||
|
||||
if not (schema_perm and self.can_access("schema_access", schema_perm)):
|
||||
datasources = SqlaTable.query_datasources_by_name(
|
||||
database, table_.table, schema=table_.schema
|
||||
)
|
||||
schema_perm = self.get_schema_perm(
|
||||
database,
|
||||
table_.catalog,
|
||||
table_.schema,
|
||||
)
|
||||
if schema_perm and self.can_access("schema_access", schema_perm):
|
||||
continue
|
||||
|
||||
# Access to any datasource is suffice.
|
||||
for datasource_ in datasources:
|
||||
if self.can_access(
|
||||
"datasource_access", datasource_.perm
|
||||
) or self.is_owner(datasource_):
|
||||
break
|
||||
else:
|
||||
denied.add(table_)
|
||||
datasources = SqlaTable.query_datasources_by_name(
|
||||
database,
|
||||
table_.table,
|
||||
schema=table_.schema,
|
||||
catalog=table_.catalog,
|
||||
)
|
||||
for datasource_ in datasources:
|
||||
if self.can_access(
|
||||
"datasource_access",
|
||||
datasource_.perm,
|
||||
) or self.is_owner(datasource_):
|
||||
# access to any datasource is sufficient
|
||||
break
|
||||
else:
|
||||
denied.add(table_)
|
||||
|
||||
if denied:
|
||||
raise SupersetSecurityException(
|
||||
|
|
|
|||
|
|
@ -119,7 +119,11 @@ def memoized_func(key: str, cache: Cache = cache_manager.cache) -> Callable[...,
|
|||
|
||||
def wrap(f: Callable[..., Any]) -> Callable[..., Any]:
|
||||
def wrapped_f(*args: Any, **kwargs: Any) -> Any:
|
||||
if not kwargs.get("cache", True):
|
||||
should_cache = kwargs.pop("cache", True)
|
||||
force = kwargs.pop("force", False)
|
||||
cache_timeout = kwargs.pop("cache_timeout", 0)
|
||||
|
||||
if not should_cache:
|
||||
return f(*args, **kwargs)
|
||||
|
||||
# format the key using args/kwargs passed to the decorated function
|
||||
|
|
@ -129,10 +133,10 @@ def memoized_func(key: str, cache: Cache = cache_manager.cache) -> Callable[...,
|
|||
cache_key = key.format(**bound_args.arguments)
|
||||
|
||||
obj = cache.get(cache_key)
|
||||
if not kwargs.get("force") and obj is not None:
|
||||
if not force and obj is not None:
|
||||
return obj
|
||||
obj = f(*args, **kwargs)
|
||||
cache.set(cache_key, obj, timeout=kwargs.get("cache_timeout", 0))
|
||||
cache.set(cache_key, obj, timeout=cache_timeout)
|
||||
return obj
|
||||
|
||||
return wrapped_f
|
||||
|
|
|
|||
|
|
@ -679,11 +679,13 @@ def generic_find_uq_constraint_name(
|
|||
|
||||
|
||||
def get_datasource_full_name(
|
||||
database_name: str, datasource_name: str, schema: str | None = None
|
||||
database_name: str,
|
||||
datasource_name: str,
|
||||
catalog: str | None = None,
|
||||
schema: str | None = None,
|
||||
) -> str:
|
||||
if not schema:
|
||||
return f"[{database_name}].[{datasource_name}]"
|
||||
return f"[{database_name}].[{schema}].[{datasource_name}]"
|
||||
parts = [database_name, catalog, schema, datasource_name]
|
||||
return ".".join([f"[{part}]" for part in parts if part])
|
||||
|
||||
|
||||
def validate_json(obj: bytes | bytearray | str) -> None:
|
||||
|
|
@ -1051,8 +1053,8 @@ def merge_extra_form_data(form_data: dict[str, Any]) -> None:
|
|||
"adhoc_filters", []
|
||||
)
|
||||
adhoc_filters.extend(
|
||||
{"isExtra": True, **fltr} # type: ignore
|
||||
for fltr in append_adhoc_filters
|
||||
{"isExtra": True, **adhoc_filter} # type: ignore
|
||||
for adhoc_filter in append_adhoc_filters
|
||||
)
|
||||
if append_filters:
|
||||
for key, value in form_data.items():
|
||||
|
|
@ -1502,6 +1504,7 @@ def shortid() -> str:
|
|||
class DatasourceName(NamedTuple):
|
||||
table: str
|
||||
schema: str
|
||||
catalog: str | None = None
|
||||
|
||||
|
||||
def get_stacktrace() -> str | None:
|
||||
|
|
|
|||
|
|
@ -32,10 +32,12 @@ def get_dataset_access_filters(
|
|||
database_ids = security_manager.get_accessible_databases()
|
||||
perms = security_manager.user_view_menu_names("datasource_access")
|
||||
schema_perms = security_manager.user_view_menu_names("schema_access")
|
||||
catalog_perms = security_manager.user_view_menu_names("catalog_access")
|
||||
|
||||
return or_(
|
||||
Database.id.in_(database_ids),
|
||||
base_model.perm.in_(perms),
|
||||
base_model.catalog_perm.in_(catalog_perms),
|
||||
base_model.schema_perm.in_(schema_perms),
|
||||
*args,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -211,11 +211,29 @@ class DatabaseMixin:
|
|||
utils.parse_ssl_cert(database.server_cert)
|
||||
database.set_sqlalchemy_uri(database.sqlalchemy_uri)
|
||||
security_manager.add_permission_view_menu("database_access", database.perm)
|
||||
# adding a new database we always want to force refresh schema list
|
||||
for schema in database.get_all_schema_names():
|
||||
security_manager.add_permission_view_menu(
|
||||
"schema_access", security_manager.get_schema_perm(database, schema)
|
||||
)
|
||||
|
||||
# add catalog/schema permissions
|
||||
if database.db_engine_spec.supports_catalog:
|
||||
catalogs = database.get_all_catalog_names()
|
||||
for catalog in catalogs:
|
||||
security_manager.add_permission_view_menu(
|
||||
"catalog_access",
|
||||
security_manager.get_catalog_perm(database.database_name, catalog),
|
||||
)
|
||||
else:
|
||||
# add a dummy catalog for DBs that don't support them
|
||||
catalogs = [None]
|
||||
|
||||
for catalog in catalogs:
|
||||
for schema in database.get_all_schema_names(catalog=catalog):
|
||||
security_manager.add_permission_view_menu(
|
||||
"schema_access",
|
||||
security_manager.get_schema_perm(
|
||||
database.database_name,
|
||||
catalog,
|
||||
schema,
|
||||
),
|
||||
)
|
||||
|
||||
def pre_add(self, database: Database) -> None:
|
||||
self._pre_add_update(database)
|
||||
|
|
|
|||
|
|
@ -36,7 +36,6 @@ from sqlalchemy.exc import DBAPIError
|
|||
from sqlalchemy.sql import func
|
||||
|
||||
from superset import db, security_manager
|
||||
from superset.commands.database.ssh_tunnel.exceptions import SSHTunnelDatabasePortError # noqa: F401
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
from superset.databases.utils import make_url_safe # noqa: F401
|
||||
|
|
@ -296,14 +295,14 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
"superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
|
||||
)
|
||||
@mock.patch("superset.commands.database.create.is_feature_enabled")
|
||||
@mock.patch(
|
||||
"superset.models.core.Database.get_all_schema_names",
|
||||
)
|
||||
@mock.patch("superset.models.core.Database.get_all_catalog_names")
|
||||
@mock.patch("superset.models.core.Database.get_all_schema_names")
|
||||
def test_create_database_with_ssh_tunnel(
|
||||
self,
|
||||
mock_test_connection_database_command_run,
|
||||
mock_create_is_feature_enabled,
|
||||
mock_get_all_schema_names,
|
||||
mock_get_all_catalog_names,
|
||||
mock_create_is_feature_enabled,
|
||||
mock_test_connection_database_command_run,
|
||||
):
|
||||
"""
|
||||
Database API: Test create with SSH Tunnel
|
||||
|
|
@ -344,14 +343,14 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
"superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
|
||||
)
|
||||
@mock.patch("superset.commands.database.create.is_feature_enabled")
|
||||
@mock.patch(
|
||||
"superset.models.core.Database.get_all_schema_names",
|
||||
)
|
||||
@mock.patch("superset.models.core.Database.get_all_catalog_names")
|
||||
@mock.patch("superset.models.core.Database.get_all_schema_names")
|
||||
def test_create_database_with_missing_port_raises_error(
|
||||
self,
|
||||
mock_test_connection_database_command_run,
|
||||
mock_create_is_feature_enabled,
|
||||
mock_get_all_schema_names,
|
||||
mock_get_all_catalog_names,
|
||||
mock_create_is_feature_enabled,
|
||||
mock_test_connection_database_command_run,
|
||||
):
|
||||
"""
|
||||
Database API: Test that missing port raises SSHTunnelDatabaseError
|
||||
|
|
@ -397,15 +396,15 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
)
|
||||
@mock.patch("superset.commands.database.create.is_feature_enabled")
|
||||
@mock.patch("superset.commands.database.update.is_feature_enabled")
|
||||
@mock.patch(
|
||||
"superset.models.core.Database.get_all_schema_names",
|
||||
)
|
||||
@mock.patch("superset.models.core.Database.get_all_catalog_names")
|
||||
@mock.patch("superset.models.core.Database.get_all_schema_names")
|
||||
def test_update_database_with_ssh_tunnel(
|
||||
self,
|
||||
mock_test_connection_database_command_run,
|
||||
mock_create_is_feature_enabled,
|
||||
mock_update_is_feature_enabled,
|
||||
mock_get_all_schema_names,
|
||||
mock_get_all_catalog_names,
|
||||
mock_update_is_feature_enabled,
|
||||
mock_create_is_feature_enabled,
|
||||
mock_test_connection_database_command_run,
|
||||
):
|
||||
"""
|
||||
Database API: Test update Database with SSH Tunnel
|
||||
|
|
@ -458,15 +457,15 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
)
|
||||
@mock.patch("superset.commands.database.create.is_feature_enabled")
|
||||
@mock.patch("superset.commands.database.update.is_feature_enabled")
|
||||
@mock.patch(
|
||||
"superset.models.core.Database.get_all_schema_names",
|
||||
)
|
||||
@mock.patch("superset.models.core.Database.get_all_catalog_names")
|
||||
@mock.patch("superset.models.core.Database.get_all_schema_names")
|
||||
def test_update_database_with_missing_port_raises_error(
|
||||
self,
|
||||
mock_test_connection_database_command_run,
|
||||
mock_create_is_feature_enabled,
|
||||
mock_update_is_feature_enabled,
|
||||
mock_get_all_schema_names,
|
||||
mock_get_all_catalog_names,
|
||||
mock_update_is_feature_enabled,
|
||||
mock_create_is_feature_enabled,
|
||||
mock_test_connection_database_command_run,
|
||||
):
|
||||
"""
|
||||
Database API: Test that missing port raises SSHTunnelDatabaseError
|
||||
|
|
@ -523,16 +522,16 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
@mock.patch("superset.commands.database.create.is_feature_enabled")
|
||||
@mock.patch("superset.commands.database.update.is_feature_enabled")
|
||||
@mock.patch("superset.commands.database.ssh_tunnel.delete.is_feature_enabled")
|
||||
@mock.patch(
|
||||
"superset.models.core.Database.get_all_schema_names",
|
||||
)
|
||||
@mock.patch("superset.models.core.Database.get_all_catalog_names")
|
||||
@mock.patch("superset.models.core.Database.get_all_schema_names")
|
||||
def test_delete_ssh_tunnel(
|
||||
self,
|
||||
mock_test_connection_database_command_run,
|
||||
mock_create_is_feature_enabled,
|
||||
mock_update_is_feature_enabled,
|
||||
mock_delete_is_feature_enabled,
|
||||
mock_get_all_schema_names,
|
||||
mock_get_all_catalog_names,
|
||||
mock_delete_is_feature_enabled,
|
||||
mock_update_is_feature_enabled,
|
||||
mock_create_is_feature_enabled,
|
||||
mock_test_connection_database_command_run,
|
||||
):
|
||||
"""
|
||||
Database API: Test deleting a SSH tunnel via Database update
|
||||
|
|
@ -606,15 +605,15 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
)
|
||||
@mock.patch("superset.commands.database.create.is_feature_enabled")
|
||||
@mock.patch("superset.commands.database.update.is_feature_enabled")
|
||||
@mock.patch(
|
||||
"superset.models.core.Database.get_all_schema_names",
|
||||
)
|
||||
@mock.patch("superset.models.core.Database.get_all_catalog_names")
|
||||
@mock.patch("superset.models.core.Database.get_all_schema_names")
|
||||
def test_update_ssh_tunnel_via_database_api(
|
||||
self,
|
||||
mock_test_connection_database_command_run,
|
||||
mock_create_is_feature_enabled,
|
||||
mock_update_is_feature_enabled,
|
||||
mock_get_all_schema_names,
|
||||
mock_get_all_catalog_names,
|
||||
mock_update_is_feature_enabled,
|
||||
mock_create_is_feature_enabled,
|
||||
mock_test_connection_database_command_run,
|
||||
):
|
||||
"""
|
||||
Database API: Test update SSH Tunnel via Database API
|
||||
|
|
@ -684,15 +683,15 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
@mock.patch(
|
||||
"superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
|
||||
)
|
||||
@mock.patch(
|
||||
"superset.models.core.Database.get_all_schema_names",
|
||||
)
|
||||
@mock.patch("superset.models.core.Database.get_all_catalog_names")
|
||||
@mock.patch("superset.models.core.Database.get_all_schema_names")
|
||||
@mock.patch("superset.commands.database.create.is_feature_enabled")
|
||||
def test_cascade_delete_ssh_tunnel(
|
||||
self,
|
||||
mock_test_connection_database_command_run,
|
||||
mock_get_all_schema_names,
|
||||
mock_create_is_feature_enabled,
|
||||
mock_get_all_schema_names,
|
||||
mock_get_all_catalog_names,
|
||||
mock_test_connection_database_command_run,
|
||||
):
|
||||
"""
|
||||
Database API: SSH Tunnel gets deleted if Database gets deleted
|
||||
|
|
@ -739,16 +738,16 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
"superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
|
||||
)
|
||||
@mock.patch("superset.commands.database.create.is_feature_enabled")
|
||||
@mock.patch(
|
||||
"superset.models.core.Database.get_all_schema_names",
|
||||
)
|
||||
@mock.patch("superset.models.core.Database.get_all_catalog_names")
|
||||
@mock.patch("superset.models.core.Database.get_all_schema_names")
|
||||
@mock.patch("superset.extensions.db.session.rollback")
|
||||
def test_do_not_create_database_if_ssh_tunnel_creation_fails(
|
||||
self,
|
||||
mock_rollback,
|
||||
mock_test_connection_database_command_run,
|
||||
mock_create_is_feature_enabled,
|
||||
mock_get_all_schema_names,
|
||||
mock_get_all_catalog_names,
|
||||
mock_create_is_feature_enabled,
|
||||
mock_test_connection_database_command_run,
|
||||
mock_rollback,
|
||||
):
|
||||
"""
|
||||
Database API: Test rollback is called if SSH Tunnel creation fails
|
||||
|
|
@ -788,14 +787,14 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
"superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
|
||||
)
|
||||
@mock.patch("superset.commands.database.create.is_feature_enabled")
|
||||
@mock.patch(
|
||||
"superset.models.core.Database.get_all_schema_names",
|
||||
)
|
||||
@mock.patch("superset.models.core.Database.get_all_catalog_names")
|
||||
@mock.patch("superset.models.core.Database.get_all_schema_names")
|
||||
def test_get_database_returns_related_ssh_tunnel(
|
||||
self,
|
||||
mock_test_connection_database_command_run,
|
||||
mock_create_is_feature_enabled,
|
||||
mock_get_all_schema_names,
|
||||
mock_get_all_catalog_names,
|
||||
mock_create_is_feature_enabled,
|
||||
mock_test_connection_database_command_run,
|
||||
):
|
||||
"""
|
||||
Database API: Test GET Database returns its related SSH Tunnel
|
||||
|
|
@ -842,13 +841,13 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
@mock.patch(
|
||||
"superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
|
||||
)
|
||||
@mock.patch(
|
||||
"superset.models.core.Database.get_all_schema_names",
|
||||
)
|
||||
@mock.patch("superset.models.core.Database.get_all_catalog_names")
|
||||
@mock.patch("superset.models.core.Database.get_all_schema_names")
|
||||
def test_if_ssh_tunneling_flag_is_not_active_it_raises_new_exception(
|
||||
self,
|
||||
mock_test_connection_database_command_run,
|
||||
mock_get_all_schema_names,
|
||||
mock_get_all_catalog_names,
|
||||
mock_test_connection_database_command_run,
|
||||
):
|
||||
"""
|
||||
Database API: Test raises SSHTunneling feature flag not enabled
|
||||
|
|
@ -1987,13 +1986,13 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
|
||||
rv = self.client.get(f"api/v1/database/{database.id}/schemas/")
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(schemas, response["result"])
|
||||
self.assertEqual(schemas, set(response["result"]))
|
||||
|
||||
rv = self.client.get(
|
||||
f"api/v1/database/{database.id}/schemas/?q={prison.dumps({'force': True})}"
|
||||
)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(schemas, response["result"])
|
||||
self.assertEqual(schemas, set(response["result"]))
|
||||
|
||||
def test_database_schemas_not_found(self):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1095,7 +1095,7 @@ class TestTablesDatabaseCommand(SupersetTestCase):
|
|||
@patch("superset.daos.database.DatabaseDAO.find_by_id")
|
||||
def test_database_tables_list_with_unknown_database(self, mock_find_by_id):
|
||||
mock_find_by_id.return_value = None
|
||||
command = TablesDatabaseCommand(1, "test", False)
|
||||
command = TablesDatabaseCommand(1, None, "test", False)
|
||||
|
||||
with pytest.raises(DatabaseNotFoundError) as excinfo:
|
||||
command.run()
|
||||
|
|
@ -1115,7 +1115,7 @@ class TestTablesDatabaseCommand(SupersetTestCase):
|
|||
mock_can_access_database.side_effect = SupersetException("Test Error")
|
||||
mock_g.user = security_manager.find_user("admin")
|
||||
|
||||
command = TablesDatabaseCommand(database.id, "main", False)
|
||||
command = TablesDatabaseCommand(database.id, None, "main", False)
|
||||
with pytest.raises(SupersetException) as excinfo:
|
||||
command.run()
|
||||
assert str(excinfo.value) == "Test Error"
|
||||
|
|
@ -1131,7 +1131,7 @@ class TestTablesDatabaseCommand(SupersetTestCase):
|
|||
mock_can_access_database.side_effect = Exception("Test Error")
|
||||
mock_g.user = security_manager.find_user("admin")
|
||||
|
||||
command = TablesDatabaseCommand(database.id, "main", False)
|
||||
command = TablesDatabaseCommand(database.id, None, "main", False)
|
||||
with pytest.raises(DatabaseTablesUnexpectedError) as excinfo:
|
||||
command.run()
|
||||
assert (
|
||||
|
|
@ -1154,7 +1154,7 @@ class TestTablesDatabaseCommand(SupersetTestCase):
|
|||
if database.backend == "postgresql" or database.backend == "mysql":
|
||||
return
|
||||
|
||||
command = TablesDatabaseCommand(database.id, schema_name, False)
|
||||
command = TablesDatabaseCommand(database.id, None, schema_name, False)
|
||||
result = command.run()
|
||||
|
||||
assert result["count"] > 0
|
||||
|
|
|
|||
|
|
@ -531,7 +531,7 @@ def test_get_catalog_names(app_context: AppContext) -> None:
|
|||
return
|
||||
|
||||
with database.get_inspector() as inspector:
|
||||
assert PostgresEngineSpec.get_catalog_names(database, inspector) == [
|
||||
assert PostgresEngineSpec.get_catalog_names(database, inspector) == {
|
||||
"postgres",
|
||||
"superset",
|
||||
]
|
||||
}
|
||||
|
|
|
|||
|
|
@ -343,20 +343,20 @@ class TestDatabaseModel(SupersetTestCase):
|
|||
main_db = get_example_database()
|
||||
|
||||
if main_db.backend == "mysql":
|
||||
df = main_db.get_df("SELECT 1", None)
|
||||
df = main_db.get_df("SELECT 1", None, None)
|
||||
self.assertEqual(df.iat[0, 0], 1)
|
||||
|
||||
df = main_db.get_df("SELECT 1;", None)
|
||||
df = main_db.get_df("SELECT 1;", None, None)
|
||||
self.assertEqual(df.iat[0, 0], 1)
|
||||
|
||||
def test_multi_statement(self):
|
||||
main_db = get_example_database()
|
||||
|
||||
if main_db.backend == "mysql":
|
||||
df = main_db.get_df("USE superset; SELECT 1", None)
|
||||
df = main_db.get_df("USE superset; SELECT 1", None, None)
|
||||
self.assertEqual(df.iat[0, 0], 1)
|
||||
|
||||
df = main_db.get_df("USE superset; SELECT ';';", None)
|
||||
df = main_db.get_df("USE superset; SELECT ';';", None, None)
|
||||
self.assertEqual(df.iat[0, 0], ";")
|
||||
|
||||
@mock.patch("superset.models.core.create_engine")
|
||||
|
|
|
|||
|
|
@ -898,7 +898,7 @@ class TestRolePermission(SupersetTestCase):
|
|||
db.session.query(SqlaTable).filter_by(table_name="tmp_table1").one()
|
||||
)
|
||||
self.assertEqual(changed_table1.perm, f"[tmp_db2].[tmp_table1](id:{table1.id})")
|
||||
self.assertEqual(changed_table1.schema_perm, f"[tmp_db2].[tmp_schema]") # noqa: F541
|
||||
self.assertEqual(changed_table1.schema_perm, "[tmp_db2].[tmp_schema]") # noqa: F541
|
||||
|
||||
# Test Chart permission changed
|
||||
slice1 = db.session.query(Slice).filter_by(slice_name="tmp_slice1").one()
|
||||
|
|
@ -956,12 +956,12 @@ class TestRolePermission(SupersetTestCase):
|
|||
db.session.query(SqlaTable).filter_by(table_name="tmp_table1").one()
|
||||
)
|
||||
self.assertEqual(changed_table1.perm, f"[tmp_db1].[tmp_table1](id:{table1.id})")
|
||||
self.assertEqual(changed_table1.schema_perm, f"[tmp_db1].[tmp_schema_changed]") # noqa: F541
|
||||
self.assertEqual(changed_table1.schema_perm, "[tmp_db1].[tmp_schema_changed]") # noqa: F541
|
||||
|
||||
# Test Chart schema permission changed
|
||||
slice1 = db.session.query(Slice).filter_by(slice_name="tmp_slice1").one()
|
||||
self.assertEqual(slice1.perm, f"[tmp_db1].[tmp_table1](id:{table1.id})")
|
||||
self.assertEqual(slice1.schema_perm, f"[tmp_db1].[tmp_schema_changed]") # noqa: F541
|
||||
self.assertEqual(slice1.schema_perm, "[tmp_db1].[tmp_schema_changed]") # noqa: F541
|
||||
|
||||
# cleanup
|
||||
db.session.delete(slice1)
|
||||
|
|
@ -1069,7 +1069,7 @@ class TestRolePermission(SupersetTestCase):
|
|||
self.assertEqual(
|
||||
changed_table1.perm, f"[tmp_db2].[tmp_table1_changed](id:{table1.id})"
|
||||
)
|
||||
self.assertEqual(changed_table1.schema_perm, f"[tmp_db2].[tmp_schema]") # noqa: F541
|
||||
self.assertEqual(changed_table1.schema_perm, "[tmp_db2].[tmp_schema]") # noqa: F541
|
||||
|
||||
# Test Chart permission changed
|
||||
slice1 = db.session.query(Slice).filter_by(slice_name="tmp_slice1").one()
|
||||
|
|
@ -1158,9 +1158,9 @@ class TestRolePermission(SupersetTestCase):
|
|||
with self.client.application.test_request_context():
|
||||
database = get_example_database()
|
||||
schemas = security_manager.get_schemas_accessible_by_user(
|
||||
database, ["1", "2", "3"]
|
||||
database, None, {"1", "2", "3"}
|
||||
)
|
||||
self.assertEqual(schemas, ["1", "2", "3"]) # no changes
|
||||
self.assertEqual(schemas, {"1", "2", "3"}) # no changes
|
||||
|
||||
@patch("superset.utils.core.g")
|
||||
@patch("superset.security.manager.g")
|
||||
|
|
@ -1171,10 +1171,10 @@ class TestRolePermission(SupersetTestCase):
|
|||
with self.client.application.test_request_context():
|
||||
database = get_example_database()
|
||||
schemas = security_manager.get_schemas_accessible_by_user(
|
||||
database, ["1", "2", "3"]
|
||||
database, None, {"1", "2", "3"}
|
||||
)
|
||||
# temp_schema is not passed in the params
|
||||
self.assertEqual(schemas, ["1"])
|
||||
self.assertEqual(schemas, {"1"})
|
||||
delete_schema_perm("[examples].[1]")
|
||||
|
||||
def test_schemas_accessible_by_user_datasource_access(self):
|
||||
|
|
@ -1183,9 +1183,9 @@ class TestRolePermission(SupersetTestCase):
|
|||
with self.client.application.test_request_context():
|
||||
with override_user(security_manager.find_user("gamma")):
|
||||
schemas = security_manager.get_schemas_accessible_by_user(
|
||||
database, ["temp_schema", "2", "3"]
|
||||
database, None, {"temp_schema", "2", "3"}
|
||||
)
|
||||
self.assertEqual(schemas, ["temp_schema"])
|
||||
self.assertEqual(schemas, {"temp_schema"})
|
||||
|
||||
def test_schemas_accessible_by_user_datasource_and_schema_access(self):
|
||||
# User has schema access to the datasource temp_schema.wb_health_population in examples DB.
|
||||
|
|
@ -1194,9 +1194,9 @@ class TestRolePermission(SupersetTestCase):
|
|||
database = get_example_database()
|
||||
with override_user(security_manager.find_user("gamma")):
|
||||
schemas = security_manager.get_schemas_accessible_by_user(
|
||||
database, ["temp_schema", "2", "3"]
|
||||
database, None, {"temp_schema", "2", "3"}
|
||||
)
|
||||
self.assertEqual(schemas, ["temp_schema", "2"])
|
||||
self.assertEqual(schemas, {"temp_schema", "2"})
|
||||
vm = security_manager.find_permission_view_menu(
|
||||
"schema_access", "[examples].[2]"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -284,9 +284,15 @@ class TestSqlLab(SupersetTestCase):
|
|||
# sqlite doesn't support database creation
|
||||
return
|
||||
|
||||
catalog = examples_db.get_default_catalog()
|
||||
sqllab_test_db_schema_permission_view = (
|
||||
security_manager.add_permission_view_menu(
|
||||
"schema_access", f"[{examples_db.name}].[{CTAS_SCHEMA_NAME}]"
|
||||
"schema_access",
|
||||
security_manager.get_schema_perm(
|
||||
examples_db.name,
|
||||
catalog,
|
||||
CTAS_SCHEMA_NAME,
|
||||
),
|
||||
)
|
||||
)
|
||||
schema_perm_role = security_manager.add_role("SchemaPermission")
|
||||
|
|
|
|||
|
|
@ -0,0 +1,128 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from superset.commands.database.create import CreateDatabaseCommand
|
||||
from superset.extensions import security_manager
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def database_with_catalog(mocker: MockerFixture) -> MagicMock:
|
||||
"""
|
||||
Mock a database with catalogs and schemas.
|
||||
"""
|
||||
mocker.patch("superset.commands.database.create.db")
|
||||
mocker.patch("superset.commands.database.create.TestConnectionDatabaseCommand")
|
||||
|
||||
database = mocker.MagicMock()
|
||||
database.database_name = "test_database"
|
||||
database.db_engine_spec.__name__ = "test_engine"
|
||||
database.db_engine_spec.supports_catalog = True
|
||||
database.get_all_catalog_names.return_value = ["catalog1", "catalog2"]
|
||||
database.get_all_schema_names.side_effect = [
|
||||
{"schema1", "schema2"},
|
||||
{"schema3", "schema4"},
|
||||
]
|
||||
|
||||
DatabaseDAO = mocker.patch("superset.commands.database.create.DatabaseDAO")
|
||||
DatabaseDAO.create.return_value = database
|
||||
|
||||
return database
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def database_without_catalog(mocker: MockerFixture) -> MagicMock:
|
||||
"""
|
||||
Mock a database without catalogs.
|
||||
"""
|
||||
mocker.patch("superset.commands.database.create.db")
|
||||
mocker.patch("superset.commands.database.create.TestConnectionDatabaseCommand")
|
||||
|
||||
database = mocker.MagicMock()
|
||||
database.database_name = "test_database"
|
||||
database.db_engine_spec.__name__ = "test_engine"
|
||||
database.db_engine_spec.supports_catalog = False
|
||||
database.get_all_schema_names.return_value = ["schema1", "schema2"]
|
||||
|
||||
DatabaseDAO = mocker.patch("superset.commands.database.create.DatabaseDAO")
|
||||
DatabaseDAO.create.return_value = database
|
||||
|
||||
return database
|
||||
|
||||
|
||||
def test_create_permissions_with_catalog(
|
||||
mocker: MockerFixture,
|
||||
database_with_catalog: MockerFixture,
|
||||
) -> None:
|
||||
"""
|
||||
Test that permissions are created when a database with a catalog is created.
|
||||
"""
|
||||
add_permission_view_menu = mocker.patch.object(
|
||||
security_manager,
|
||||
"add_permission_view_menu",
|
||||
)
|
||||
|
||||
CreateDatabaseCommand(
|
||||
{
|
||||
"database_name": "test_database",
|
||||
"sqlalchemy_uri": "sqlite://",
|
||||
}
|
||||
).run()
|
||||
|
||||
add_permission_view_menu.assert_has_calls(
|
||||
[
|
||||
mocker.call("catalog_access", "[test_database].[catalog1]"),
|
||||
mocker.call("catalog_access", "[test_database].[catalog2]"),
|
||||
mocker.call("schema_access", "[test_database].[catalog1].[schema1]"),
|
||||
mocker.call("schema_access", "[test_database].[catalog1].[schema2]"),
|
||||
mocker.call("schema_access", "[test_database].[catalog2].[schema3]"),
|
||||
mocker.call("schema_access", "[test_database].[catalog2].[schema4]"),
|
||||
],
|
||||
any_order=True,
|
||||
)
|
||||
|
||||
|
||||
def test_create_permissions_without_catalog(
|
||||
mocker: MockerFixture,
|
||||
database_without_catalog: MockerFixture,
|
||||
) -> None:
|
||||
"""
|
||||
Test that permissions are created when a database without a catalog is created.
|
||||
"""
|
||||
add_permission_view_menu = mocker.patch.object(
|
||||
security_manager,
|
||||
"add_permission_view_menu",
|
||||
)
|
||||
|
||||
CreateDatabaseCommand(
|
||||
{
|
||||
"database_name": "test_database",
|
||||
"sqlalchemy_uri": "sqlite://",
|
||||
}
|
||||
).run()
|
||||
|
||||
add_permission_view_menu.assert_has_calls(
|
||||
[
|
||||
mocker.call("schema_access", "[test_database].[schema1]"),
|
||||
mocker.call("schema_access", "[test_database].[schema2]"),
|
||||
],
|
||||
any_order=True,
|
||||
)
|
||||
|
|
@ -0,0 +1,203 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from superset.commands.database.tables import TablesDatabaseCommand
|
||||
from superset.extensions import security_manager
|
||||
from superset.utils.core import DatasourceName
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def database_with_catalog(mocker: MockerFixture) -> MagicMock:
|
||||
"""
|
||||
Mock a database with catalogs and schemas.
|
||||
"""
|
||||
mocker.patch("superset.commands.database.tables.db")
|
||||
|
||||
database = mocker.MagicMock()
|
||||
database.database_name = "test_database"
|
||||
database.get_all_table_names_in_schema.return_value = [
|
||||
DatasourceName("table1", "schema1", "catalog1"),
|
||||
DatasourceName("table2", "schema1", "catalog1"),
|
||||
]
|
||||
database.get_all_view_names_in_schema.return_value = [
|
||||
DatasourceName("view1", "schema1", "catalog1"),
|
||||
]
|
||||
|
||||
DatabaseDAO = mocker.patch("superset.commands.database.tables.DatabaseDAO")
|
||||
DatabaseDAO.find_by_id.return_value = database
|
||||
|
||||
return database
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def database_without_catalog(mocker: MockerFixture) -> MagicMock:
|
||||
"""
|
||||
Mock a database without catalogs but with schemas.
|
||||
"""
|
||||
mocker.patch("superset.commands.database.tables.db")
|
||||
|
||||
database = mocker.MagicMock()
|
||||
database.database_name = "test_database"
|
||||
database.get_all_table_names_in_schema.return_value = [
|
||||
DatasourceName("table1", "schema1"),
|
||||
DatasourceName("table2", "schema1"),
|
||||
]
|
||||
database.get_all_view_names_in_schema.return_value = [
|
||||
DatasourceName("view1", "schema1"),
|
||||
]
|
||||
|
||||
DatabaseDAO = mocker.patch("superset.commands.database.tables.DatabaseDAO")
|
||||
DatabaseDAO.find_by_id.return_value = database
|
||||
|
||||
return database
|
||||
|
||||
|
||||
def test_tables_with_catalog(
|
||||
mocker: MockerFixture,
|
||||
database_with_catalog: MockerFixture,
|
||||
) -> None:
|
||||
"""
|
||||
Test that permissions are created when a database with a catalog is created.
|
||||
"""
|
||||
get_datasources_accessible_by_user = mocker.patch.object(
|
||||
security_manager,
|
||||
"get_datasources_accessible_by_user",
|
||||
side_effect=[
|
||||
{
|
||||
DatasourceName("table1", "schema1", "catalog1"),
|
||||
DatasourceName("table2", "schema1", "catalog1"),
|
||||
},
|
||||
{DatasourceName("view1", "schema1", "catalog1")},
|
||||
],
|
||||
)
|
||||
|
||||
db = mocker.patch("superset.commands.database.tables.db")
|
||||
table = mocker.MagicMock()
|
||||
table.name = "table1"
|
||||
table.extra_dict = {"foo": "bar"}
|
||||
db.session.query().filter().options().all.return_value = [table]
|
||||
|
||||
payload = TablesDatabaseCommand(1, "catalog1", "schema1", False).run()
|
||||
assert payload == {
|
||||
"count": 3,
|
||||
"result": [
|
||||
{"value": "table1", "type": "table", "extra": {"foo": "bar"}},
|
||||
{"value": "table2", "type": "table", "extra": None},
|
||||
{"value": "view1", "type": "view"},
|
||||
],
|
||||
}
|
||||
|
||||
get_datasources_accessible_by_user.assert_has_calls(
|
||||
[
|
||||
mocker.call(
|
||||
database=database_with_catalog,
|
||||
catalog="catalog1",
|
||||
schema="schema1",
|
||||
datasource_names=[
|
||||
DatasourceName("table1", "schema1", "catalog1"),
|
||||
DatasourceName("table2", "schema1", "catalog1"),
|
||||
],
|
||||
),
|
||||
mocker.call(
|
||||
database=database_with_catalog,
|
||||
catalog="catalog1",
|
||||
schema="schema1",
|
||||
datasource_names=[
|
||||
DatasourceName("view1", "schema1", "catalog1"),
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
database_with_catalog.get_all_table_names_in_schema.assert_called_with(
|
||||
catalog="catalog1",
|
||||
schema="schema1",
|
||||
force=False,
|
||||
cache=database_with_catalog.table_cache_enabled,
|
||||
cache_timeout=database_with_catalog.table_cache_timeout,
|
||||
)
|
||||
|
||||
|
||||
def test_tables_without_catalog(
|
||||
mocker: MockerFixture,
|
||||
database_without_catalog: MockerFixture,
|
||||
) -> None:
|
||||
"""
|
||||
Test that permissions are created when a database without a catalog is created.
|
||||
"""
|
||||
get_datasources_accessible_by_user = mocker.patch.object(
|
||||
security_manager,
|
||||
"get_datasources_accessible_by_user",
|
||||
side_effect=[
|
||||
{
|
||||
DatasourceName("table1", "schema1"),
|
||||
DatasourceName("table2", "schema1"),
|
||||
},
|
||||
{DatasourceName("view1", "schema1")},
|
||||
],
|
||||
)
|
||||
|
||||
db = mocker.patch("superset.commands.database.tables.db")
|
||||
table = mocker.MagicMock()
|
||||
table.name = "table1"
|
||||
table.extra_dict = {"foo": "bar"}
|
||||
db.session.query().filter().options().all.return_value = [table]
|
||||
|
||||
payload = TablesDatabaseCommand(1, None, "schema1", False).run()
|
||||
assert payload == {
|
||||
"count": 3,
|
||||
"result": [
|
||||
{"value": "table1", "type": "table", "extra": {"foo": "bar"}},
|
||||
{"value": "table2", "type": "table", "extra": None},
|
||||
{"value": "view1", "type": "view"},
|
||||
],
|
||||
}
|
||||
|
||||
get_datasources_accessible_by_user.assert_has_calls(
|
||||
[
|
||||
mocker.call(
|
||||
database=database_without_catalog,
|
||||
catalog=None,
|
||||
schema="schema1",
|
||||
datasource_names=[
|
||||
DatasourceName("table1", "schema1"),
|
||||
DatasourceName("table2", "schema1"),
|
||||
],
|
||||
),
|
||||
mocker.call(
|
||||
database=database_without_catalog,
|
||||
catalog=None,
|
||||
schema="schema1",
|
||||
datasource_names=[
|
||||
DatasourceName("view1", "schema1"),
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
database_without_catalog.get_all_table_names_in_schema.assert_called_with(
|
||||
catalog=None,
|
||||
schema="schema1",
|
||||
force=False,
|
||||
cache=database_without_catalog.table_cache_enabled,
|
||||
cache_timeout=database_without_catalog.table_cache_timeout,
|
||||
)
|
||||
|
|
@ -0,0 +1,272 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from superset.commands.database.update import UpdateDatabaseCommand
|
||||
from superset.extensions import security_manager
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def database_with_catalog(mocker: MockerFixture) -> MagicMock:
|
||||
"""
|
||||
Mock a database with catalogs and schemas.
|
||||
"""
|
||||
mocker.patch("superset.commands.database.update.db")
|
||||
|
||||
database = mocker.MagicMock()
|
||||
database.database_name = "my_db"
|
||||
database.db_engine_spec.__name__ = "test_engine"
|
||||
database.db_engine_spec.supports_catalog = True
|
||||
database.get_all_catalog_names.return_value = ["catalog1", "catalog2"]
|
||||
database.get_all_schema_names.side_effect = [
|
||||
["schema1", "schema2"],
|
||||
["schema3", "schema4"],
|
||||
]
|
||||
database.get_default_catalog.return_value = "catalog2"
|
||||
|
||||
return database
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def database_without_catalog(mocker: MockerFixture) -> MagicMock:
|
||||
"""
|
||||
Mock a database without catalogs.
|
||||
"""
|
||||
mocker.patch("superset.commands.database.update.db")
|
||||
|
||||
database = mocker.MagicMock()
|
||||
database.database_name = "my_db"
|
||||
database.db_engine_spec.__name__ = "test_engine"
|
||||
database.db_engine_spec.supports_catalog = False
|
||||
database.get_all_schema_names.return_value = ["schema1", "schema2"]
|
||||
|
||||
return database
|
||||
|
||||
|
||||
def test_update_with_catalog(
|
||||
mocker: MockerFixture,
|
||||
database_with_catalog: MockerFixture,
|
||||
) -> None:
|
||||
"""
|
||||
Test that permissions are updated correctly.
|
||||
|
||||
In this test, the database has two catalogs with two schemas each:
|
||||
|
||||
- catalog1
|
||||
- schema1
|
||||
- schema2
|
||||
- catalog2
|
||||
- schema3
|
||||
- schema4
|
||||
|
||||
When update is called, only `catalog2.schema3` has permissions associated with it,
|
||||
so `catalog1.*` and `catalog2.schema4` are added.
|
||||
"""
|
||||
DatabaseDAO = mocker.patch("superset.commands.database.update.DatabaseDAO")
|
||||
DatabaseDAO.find_by_id.return_value = database_with_catalog
|
||||
DatabaseDAO.update.return_value = database_with_catalog
|
||||
|
||||
find_permission_view_menu = mocker.patch.object(
|
||||
security_manager,
|
||||
"find_permission_view_menu",
|
||||
)
|
||||
find_permission_view_menu.side_effect = [
|
||||
None, # first catalog is new
|
||||
"[my_db].[catalog2]", # second catalog already exists
|
||||
"[my_db].[catalog2].[schema3]", # first schema already exists
|
||||
None, # second schema is new
|
||||
# these are called when checking for existing perms in [db].[schema] format
|
||||
None,
|
||||
None,
|
||||
]
|
||||
add_permission_view_menu = mocker.patch.object(
|
||||
security_manager,
|
||||
"add_permission_view_menu",
|
||||
)
|
||||
|
||||
UpdateDatabaseCommand(1, {}).run()
|
||||
|
||||
add_permission_view_menu.assert_has_calls(
|
||||
[
|
||||
# first catalog is added with all schemas
|
||||
mocker.call("catalog_access", "[my_db].[catalog1]"),
|
||||
mocker.call("schema_access", "[my_db].[catalog1].[schema1]"),
|
||||
mocker.call("schema_access", "[my_db].[catalog1].[schema2]"),
|
||||
# second catalog already exists, only `schema4` is added
|
||||
mocker.call("schema_access", "[my_db].[catalog2].[schema4]"),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def test_update_without_catalog(
|
||||
mocker: MockerFixture,
|
||||
database_without_catalog: MockerFixture,
|
||||
) -> None:
|
||||
"""
|
||||
Test that permissions are updated correctly.
|
||||
|
||||
In this test, the database has no catalogs and two schemas:
|
||||
|
||||
- schema1
|
||||
- schema2
|
||||
|
||||
When update is called, only `schema2` has permissions associated with it, so `schema1`
|
||||
is added.
|
||||
"""
|
||||
DatabaseDAO = mocker.patch("superset.commands.database.update.DatabaseDAO")
|
||||
DatabaseDAO.find_by_id.return_value = database_without_catalog
|
||||
DatabaseDAO.update.return_value = database_without_catalog
|
||||
|
||||
find_permission_view_menu = mocker.patch.object(
|
||||
security_manager,
|
||||
"find_permission_view_menu",
|
||||
)
|
||||
find_permission_view_menu.side_effect = [
|
||||
None, # schema1 has no permissions
|
||||
"[my_db].[schema2]", # second schema already exists
|
||||
]
|
||||
add_permission_view_menu = mocker.patch.object(
|
||||
security_manager,
|
||||
"add_permission_view_menu",
|
||||
)
|
||||
|
||||
UpdateDatabaseCommand(1, {}).run()
|
||||
|
||||
add_permission_view_menu.assert_called_with(
|
||||
"schema_access",
|
||||
"[my_db].[schema1]",
|
||||
)
|
||||
|
||||
|
||||
def test_rename_with_catalog(
|
||||
mocker: MockerFixture,
|
||||
database_with_catalog: MockerFixture,
|
||||
) -> None:
|
||||
"""
|
||||
Test that permissions are renamed correctly.
|
||||
|
||||
In this test, the database has two catalogs with two schemas each:
|
||||
|
||||
- catalog1
|
||||
- schema1
|
||||
- schema2
|
||||
- catalog2
|
||||
- schema3
|
||||
- schema4
|
||||
|
||||
When update is called, only `catalog2.schema3` has permissions associated with it,
|
||||
so `catalog1.*` and `catalog2.schema4` are added. Additionally, the database has
|
||||
been renamed from `my_db` to `my_other_db`.
|
||||
"""
|
||||
DatabaseDAO = mocker.patch("superset.commands.database.update.DatabaseDAO")
|
||||
original_database = mocker.MagicMock()
|
||||
original_database.database_name = "my_db"
|
||||
DatabaseDAO.find_by_id.return_value = original_database
|
||||
database_with_catalog.database_name = "my_other_db"
|
||||
DatabaseDAO.update.return_value = database_with_catalog
|
||||
DatabaseDAO.get_datasets.return_value = []
|
||||
|
||||
find_permission_view_menu = mocker.patch.object(
|
||||
security_manager,
|
||||
"find_permission_view_menu",
|
||||
)
|
||||
catalog2_pvm = mocker.MagicMock()
|
||||
catalog2_schema3_pvm = mocker.MagicMock()
|
||||
find_permission_view_menu.side_effect = [
|
||||
# these are called when adding the permissions:
|
||||
None, # first catalog is new
|
||||
"[my_db].[catalog2]", # second catalog already exists
|
||||
"[my_db].[catalog2].[schema3]", # first schema already exists
|
||||
None, # second schema is new
|
||||
# these are called when renaming the permissions:
|
||||
catalog2_pvm, # old [my_db].[catalog2]
|
||||
catalog2_schema3_pvm, # old [my_db].[catalog2].[schema3]
|
||||
None, # [my_db].[catalog2].[schema4] doesn't exist
|
||||
]
|
||||
add_permission_view_menu = mocker.patch.object(
|
||||
security_manager,
|
||||
"add_permission_view_menu",
|
||||
)
|
||||
|
||||
UpdateDatabaseCommand(1, {}).run()
|
||||
|
||||
add_permission_view_menu.assert_has_calls(
|
||||
[
|
||||
# first catalog is added with all schemas with the new DB name
|
||||
mocker.call("catalog_access", "[my_other_db].[catalog1]"),
|
||||
mocker.call("schema_access", "[my_other_db].[catalog1].[schema1]"),
|
||||
mocker.call("schema_access", "[my_other_db].[catalog1].[schema2]"),
|
||||
# second catalog already exists, only `schema4` is added
|
||||
mocker.call("schema_access", "[my_other_db].[catalog2].[schema4]"),
|
||||
],
|
||||
)
|
||||
|
||||
assert catalog2_pvm.view_menu.name == "[my_other_db].[catalog2]"
|
||||
assert catalog2_schema3_pvm.view_menu.name == "[my_other_db].[catalog2].[schema3]"
|
||||
|
||||
|
||||
def test_rename_without_catalog(
|
||||
mocker: MockerFixture,
|
||||
database_without_catalog: MockerFixture,
|
||||
) -> None:
|
||||
"""
|
||||
Test that permissions are renamed correctly.
|
||||
|
||||
In this test, the database has no catalogs and two schemas:
|
||||
|
||||
- schema1
|
||||
- schema2
|
||||
|
||||
When update is called, only `schema2` has permissions associated with it, so `schema1`
|
||||
is added. Additionally, the database has been renamed from `my_db` to `my_other_db`.
|
||||
"""
|
||||
DatabaseDAO = mocker.patch("superset.commands.database.update.DatabaseDAO")
|
||||
original_database = mocker.MagicMock()
|
||||
original_database.database_name = "my_db"
|
||||
DatabaseDAO.find_by_id.return_value = original_database
|
||||
database_without_catalog.database_name = "my_other_db"
|
||||
DatabaseDAO.update.return_value = database_without_catalog
|
||||
DatabaseDAO.get_datasets.return_value = []
|
||||
|
||||
find_permission_view_menu = mocker.patch.object(
|
||||
security_manager,
|
||||
"find_permission_view_menu",
|
||||
)
|
||||
schema2_pvm = mocker.MagicMock()
|
||||
find_permission_view_menu.side_effect = [
|
||||
None, # schema1 has no permissions
|
||||
"[my_db].[schema2]", # second schema already exists
|
||||
None, # [my_db].[schema1] doesn't exist
|
||||
schema2_pvm, # old [my_db].[schema2]
|
||||
]
|
||||
add_permission_view_menu = mocker.patch.object(
|
||||
security_manager,
|
||||
"add_permission_view_menu",
|
||||
)
|
||||
|
||||
UpdateDatabaseCommand(1, {}).run()
|
||||
|
||||
add_permission_view_menu.assert_called_with(
|
||||
"schema_access",
|
||||
"[my_other_db].[schema1]",
|
||||
)
|
||||
|
||||
assert schema2_pvm.view_menu.name == "[my_other_db].[schema2]"
|
||||
|
|
@ -88,6 +88,8 @@ def app(request: SubRequest) -> Iterator[SupersetApp]:
|
|||
app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] = False
|
||||
app.config["TESTING"] = True
|
||||
app.config["RATELIMIT_ENABLED"] = False
|
||||
app.config["CACHE_CONFIG"] = {}
|
||||
app.config["DATA_CACHE_CONFIG"] = {}
|
||||
|
||||
# loop over extra configs passed in by tests
|
||||
# and update the app config
|
||||
|
|
|
|||
|
|
@ -17,9 +17,11 @@
|
|||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
from sqlalchemy import create_engine
|
||||
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.exceptions import OAuth2RedirectError
|
||||
from superset.models.core import Database
|
||||
from superset.superset_typing import QueryObjectDict
|
||||
|
||||
|
||||
|
|
@ -64,3 +66,124 @@ def test_query_bubbles_errors(mocker: MockerFixture) -> None:
|
|||
}
|
||||
with pytest.raises(OAuth2RedirectError):
|
||||
sqla_table.query(query_obj)
|
||||
|
||||
|
||||
def test_permissions_without_catalog() -> None:
|
||||
"""
|
||||
Test permissions when the table has no catalog.
|
||||
"""
|
||||
database = Database(database_name="my_db")
|
||||
sqla_table = SqlaTable(
|
||||
table_name="my_sqla_table",
|
||||
columns=[],
|
||||
metrics=[],
|
||||
database=database,
|
||||
schema="schema1",
|
||||
catalog=None,
|
||||
id=1,
|
||||
)
|
||||
|
||||
assert sqla_table.get_perm() == "[my_db].[my_sqla_table](id:1)"
|
||||
assert sqla_table.get_catalog_perm() is None
|
||||
assert sqla_table.get_schema_perm() == "[my_db].[schema1]"
|
||||
|
||||
|
||||
def test_permissions_with_catalog() -> None:
|
||||
"""
|
||||
Test permissions when the table with a catalog set.
|
||||
"""
|
||||
database = Database(database_name="my_db")
|
||||
sqla_table = SqlaTable(
|
||||
table_name="my_sqla_table",
|
||||
columns=[],
|
||||
metrics=[],
|
||||
database=database,
|
||||
schema="schema1",
|
||||
catalog="db1",
|
||||
id=1,
|
||||
)
|
||||
|
||||
assert sqla_table.get_perm() == "[my_db].[my_sqla_table](id:1)"
|
||||
assert sqla_table.get_catalog_perm() == "[my_db].[db1]"
|
||||
assert sqla_table.get_schema_perm() == "[my_db].[db1].[schema1]"
|
||||
|
||||
|
||||
def test_query_datasources_by_name(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test the `query_datasources_by_name` method.
|
||||
"""
|
||||
db = mocker.patch("superset.connectors.sqla.models.db")
|
||||
|
||||
database = Database(database_name="my_db", id=1)
|
||||
sqla_table = SqlaTable(
|
||||
table_name="my_sqla_table",
|
||||
columns=[],
|
||||
metrics=[],
|
||||
database=database,
|
||||
)
|
||||
|
||||
sqla_table.query_datasources_by_name(database, "my_table")
|
||||
db.session.query().filter_by.assert_called_with(
|
||||
database_id=1,
|
||||
table_name="my_table",
|
||||
)
|
||||
|
||||
sqla_table.query_datasources_by_name(database, "my_table", "db1", "schema1")
|
||||
db.session.query().filter_by.assert_called_with(
|
||||
database_id=1,
|
||||
table_name="my_table",
|
||||
catalog="db1",
|
||||
schema="schema1",
|
||||
)
|
||||
|
||||
|
||||
def test_query_datasources_by_permissions(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test the `query_datasources_by_permissions` method.
|
||||
"""
|
||||
db = mocker.patch("superset.connectors.sqla.models.db")
|
||||
|
||||
engine = create_engine("sqlite://")
|
||||
database = Database(database_name="my_db", id=1)
|
||||
sqla_table = SqlaTable(
|
||||
table_name="my_sqla_table",
|
||||
columns=[],
|
||||
metrics=[],
|
||||
database=database,
|
||||
)
|
||||
|
||||
sqla_table.query_datasources_by_permissions(database, set(), set(), set())
|
||||
db.session.query().filter_by.assert_called_with(database_id=1)
|
||||
clause = db.session.query().filter_by().filter.mock_calls[0].args[0]
|
||||
assert str(clause.compile(engine, compile_kwargs={"literal_binds": True})) == ""
|
||||
|
||||
|
||||
def test_query_datasources_by_permissions_with_catalog_schema(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""
|
||||
Test the `query_datasources_by_permissions` method passing a catalog and schema.
|
||||
"""
|
||||
db = mocker.patch("superset.connectors.sqla.models.db")
|
||||
|
||||
engine = create_engine("sqlite://")
|
||||
database = Database(database_name="my_db", id=1)
|
||||
sqla_table = SqlaTable(
|
||||
table_name="my_sqla_table",
|
||||
columns=[],
|
||||
metrics=[],
|
||||
database=database,
|
||||
)
|
||||
sqla_table.query_datasources_by_permissions(
|
||||
database,
|
||||
{"[my_db].[table1](id:1)"},
|
||||
{"[my_db].[db1]"},
|
||||
# pass as list to have deterministic order for test
|
||||
["[my_db].[db1].[schema1]", "[my_other_db].[schema]"], # type: ignore
|
||||
)
|
||||
clause = db.session.query().filter_by().filter.mock_calls[0].args[0]
|
||||
assert str(clause.compile(engine, compile_kwargs={"literal_binds": True})) == (
|
||||
"tables.perm IN ('[my_db].[table1](id:1)') OR "
|
||||
"tables.schema_perm IN ('[my_db].[db1].[schema1]', '[my_other_db].[schema]') OR "
|
||||
"tables.catalog_perm IN ('[my_db].[db1]')"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -2066,3 +2066,101 @@ def test_table_extra_metadata_unauthorized(
|
|||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
def test_catalogs(
|
||||
mocker: MockFixture,
|
||||
client: Any,
|
||||
full_api_access: None,
|
||||
) -> None:
|
||||
"""
|
||||
Test the `catalogs` endpoint.
|
||||
"""
|
||||
database = mocker.MagicMock()
|
||||
database.get_all_catalog_names.return_value = {"db1", "db2"}
|
||||
DatabaseDAO = mocker.patch("superset.databases.api.DatabaseDAO")
|
||||
DatabaseDAO.find_by_id.return_value = database
|
||||
|
||||
security_manager = mocker.patch(
|
||||
"superset.databases.api.security_manager",
|
||||
new=mocker.MagicMock(),
|
||||
)
|
||||
security_manager.get_catalogs_accessible_by_user.return_value = {"db2"}
|
||||
|
||||
response = client.get("/api/v1/database/1/catalogs/")
|
||||
assert response.status_code == 200
|
||||
assert response.json == {"result": ["db2"]}
|
||||
database.get_all_catalog_names.assert_called_with(
|
||||
cache=database.catalog_cache_enabled,
|
||||
cache_timeout=database.catalog_cache_timeout,
|
||||
force=False,
|
||||
)
|
||||
security_manager.get_catalogs_accessible_by_user.assert_called_with(
|
||||
database,
|
||||
{"db1", "db2"},
|
||||
)
|
||||
|
||||
response = client.get("/api/v1/database/1/catalogs/?q=(force:!t)")
|
||||
database.get_all_catalog_names.assert_called_with(
|
||||
cache=database.catalog_cache_enabled,
|
||||
cache_timeout=database.catalog_cache_timeout,
|
||||
force=True,
|
||||
)
|
||||
|
||||
|
||||
def test_schemas(
|
||||
mocker: MockFixture,
|
||||
client: Any,
|
||||
full_api_access: None,
|
||||
) -> None:
|
||||
"""
|
||||
Test the `schemas` endpoint.
|
||||
"""
|
||||
from superset.databases.api import DatabaseRestApi
|
||||
|
||||
database = mocker.MagicMock()
|
||||
database.get_all_schema_names.return_value = {"schema1", "schema2"}
|
||||
datamodel = mocker.patch.object(DatabaseRestApi, "datamodel")
|
||||
datamodel.get.return_value = database
|
||||
|
||||
security_manager = mocker.patch(
|
||||
"superset.databases.api.security_manager",
|
||||
new=mocker.MagicMock(),
|
||||
)
|
||||
security_manager.get_schemas_accessible_by_user.return_value = {"schema2"}
|
||||
|
||||
response = client.get("/api/v1/database/1/schemas/")
|
||||
assert response.status_code == 200
|
||||
assert response.json == {"result": ["schema2"]}
|
||||
database.get_all_schema_names.assert_called_with(
|
||||
catalog=None,
|
||||
cache=database.schema_cache_enabled,
|
||||
cache_timeout=database.schema_cache_timeout,
|
||||
force=False,
|
||||
)
|
||||
security_manager.get_schemas_accessible_by_user.assert_called_with(
|
||||
database,
|
||||
None,
|
||||
{"schema1", "schema2"},
|
||||
)
|
||||
|
||||
response = client.get("/api/v1/database/1/schemas/?q=(force:!t)")
|
||||
database.get_all_schema_names.assert_called_with(
|
||||
catalog=None,
|
||||
cache=database.schema_cache_enabled,
|
||||
cache_timeout=database.schema_cache_timeout,
|
||||
force=True,
|
||||
)
|
||||
|
||||
response = client.get("/api/v1/database/1/schemas/?q=(force:!t,catalog:catalog2)")
|
||||
database.get_all_schema_names.assert_called_with(
|
||||
catalog="catalog2",
|
||||
cache=database.schema_cache_enabled,
|
||||
cache_timeout=database.schema_cache_timeout,
|
||||
force=True,
|
||||
)
|
||||
security_manager.get_schemas_accessible_by_user.assert_called_with(
|
||||
database,
|
||||
"catalog2",
|
||||
{"schema1", "schema2"},
|
||||
)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,128 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from flask_appbuilder.models.sqla.interface import SQLAInterface
|
||||
from pytest_mock import MockerFixture
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from superset.databases.filters import can_access_databases, DatabaseFilter
|
||||
from superset.extensions import security_manager
|
||||
|
||||
|
||||
def test_can_access_databases(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test the `can_access_databases` function.
|
||||
"""
|
||||
mocker.patch.object(
|
||||
security_manager,
|
||||
"user_view_menu_names",
|
||||
side_effect=[
|
||||
{
|
||||
"[my_db].[examples].[public].[table1](id:1)",
|
||||
"[my_other_db].[examples].[public].[table1](id:2)",
|
||||
},
|
||||
{"[my_db].(id:42)", "[my_other_db].(id:43)"},
|
||||
{"[my_db].[examples]", "[my_db].[other]"},
|
||||
{
|
||||
"[my_db].[examples].[information_schema]",
|
||||
"[my_db].[other].[secret]",
|
||||
"[third_db].[schema]",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
assert can_access_databases("datasource_access") == {"my_db", "my_other_db"}
|
||||
assert can_access_databases("database_access") == {"my_db", "my_other_db"}
|
||||
assert can_access_databases("catalog_access") == {"my_db"}
|
||||
assert can_access_databases("schema_access") == {"my_db", "third_db"}
|
||||
|
||||
|
||||
def test_database_filter_full_db_access(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test the `DatabaseFilter` class when the user has full database access.
|
||||
|
||||
In this case the query should be returned unmodified.
|
||||
"""
|
||||
from superset.models.core import Database
|
||||
|
||||
current_app = mocker.patch("superset.databases.filters.current_app")
|
||||
current_app.config = {"EXTRA_DYNAMIC_QUERY_FILTERS": False}
|
||||
mocker.patch.object(security_manager, "can_access_all_databases", return_value=True)
|
||||
|
||||
engine = create_engine("sqlite://")
|
||||
Session = sessionmaker(bind=engine)
|
||||
session = Session()
|
||||
query = session.query(Database)
|
||||
|
||||
filter_ = DatabaseFilter("id", SQLAInterface(Database))
|
||||
filtered_query = filter_.apply(query, None)
|
||||
|
||||
assert filtered_query == query
|
||||
|
||||
|
||||
def test_database_filter(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test the `DatabaseFilter` class with specific permissions.
|
||||
"""
|
||||
from superset.models.core import Database
|
||||
|
||||
current_app = mocker.patch("superset.databases.filters.current_app")
|
||||
current_app.config = {"EXTRA_DYNAMIC_QUERY_FILTERS": False}
|
||||
mocker.patch.object(
|
||||
security_manager,
|
||||
"can_access_all_databases",
|
||||
return_value=False,
|
||||
)
|
||||
mocker.patch.object(
|
||||
security_manager,
|
||||
"user_view_menu_names",
|
||||
side_effect=[
|
||||
# return lists instead of sets to ensure order
|
||||
["[my_db].(id:42)", "[my_other_db].(id:43)"],
|
||||
["[my_db].[examples]", "[my_db].[other]"],
|
||||
[
|
||||
"[my_db].[examples].[information_schema]",
|
||||
"[my_db].[other].[secret]",
|
||||
"[third_db].[schema]",
|
||||
],
|
||||
[
|
||||
"[my_db].[examples].[public].[table1](id:1)",
|
||||
"[my_other_db].[examples].[public].[table1](id:2)",
|
||||
],
|
||||
],
|
||||
)
|
||||
|
||||
engine = create_engine("sqlite://")
|
||||
Session = sessionmaker(bind=engine)
|
||||
session = Session()
|
||||
query = session.query(Database)
|
||||
|
||||
filter_ = DatabaseFilter("id", SQLAInterface(Database))
|
||||
filtered_query = filter_.apply(query, None)
|
||||
|
||||
compiled_query = filtered_query.statement.compile(
|
||||
engine,
|
||||
compile_kwargs={"literal_binds": True},
|
||||
)
|
||||
space = " " # pre-commit removes trailing spaces...
|
||||
assert (
|
||||
str(compiled_query)
|
||||
== f"""SELECT dbs.uuid, dbs.created_on, dbs.changed_on, dbs.id, dbs.verbose_name, dbs.database_name, dbs.sqlalchemy_uri, dbs.password, dbs.cache_timeout, dbs.select_as_create_table_as, dbs.expose_in_sqllab, dbs.configuration_method, dbs.allow_run_async, dbs.allow_file_upload, dbs.allow_ctas, dbs.allow_cvas, dbs.allow_dml, dbs.force_ctas_schema, dbs.extra, dbs.encrypted_extra, dbs.impersonate_user, dbs.server_cert, dbs.is_managed_externally, dbs.external_url, dbs.created_by_fk, dbs.changed_by_fk{space}
|
||||
FROM dbs{space}
|
||||
WHERE '[' || dbs.database_name || '].(id:' || CAST(dbs.id AS VARCHAR) || ')' IN ('[my_db].(id:42)', '[my_other_db].(id:43)') OR dbs.database_name IN ('my_db', 'my_other_db', 'third_db')"""
|
||||
)
|
||||
|
|
@ -301,3 +301,13 @@ def test_extra_table_metadata(mocker: MockFixture) -> None:
|
|||
)
|
||||
|
||||
warnings.warn.assert_called()
|
||||
|
||||
|
||||
def test_get_default_catalog(mocker: MockFixture) -> None:
|
||||
"""
|
||||
Test the `get_default_catalog` method.
|
||||
"""
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
|
||||
database = mocker.MagicMock()
|
||||
assert BaseEngineSpec.get_default_catalog(database) is None
|
||||
|
|
|
|||
|
|
@ -175,3 +175,33 @@ SELECT * FROM some_table;
|
|||
str(excinfo.value)
|
||||
== "Users are not allowed to set a search path for security reasons."
|
||||
)
|
||||
|
||||
|
||||
def test_adjust_engine_params() -> None:
|
||||
"""
|
||||
Test `adjust_engine_params`.
|
||||
|
||||
The method can be used to adjust the catalog (database) dynamically.
|
||||
"""
|
||||
from superset.db_engine_specs.postgres import PostgresEngineSpec
|
||||
|
||||
adjusted = PostgresEngineSpec.adjust_engine_params(
|
||||
make_url("postgresql://user:password@host:5432/dev"),
|
||||
{},
|
||||
catalog="prod",
|
||||
)
|
||||
assert adjusted == (make_url("postgresql://user:password@host:5432/prod"), {})
|
||||
|
||||
|
||||
def test_get_default_catalog() -> None:
|
||||
"""
|
||||
Test `get_default_catalog`.
|
||||
"""
|
||||
from superset.db_engine_specs.postgres import PostgresEngineSpec
|
||||
from superset.models.core import Database
|
||||
|
||||
database = Database(
|
||||
database_name="postgres",
|
||||
sqlalchemy_uri="postgresql://user:password@host:5432/dev",
|
||||
)
|
||||
assert PostgresEngineSpec.get_default_catalog(database) == "dev"
|
||||
|
|
|
|||
|
|
@ -272,6 +272,7 @@ def test_query_no_access(mocker: MockFixture, client) -> None:
|
|||
from superset.models.sql_lab import Query
|
||||
|
||||
database = mocker.MagicMock()
|
||||
database.get_default_catalog.return_value = None
|
||||
database.get_default_schema_for_query.return_value = "public"
|
||||
mocker.patch(
|
||||
query_find_by_id,
|
||||
|
|
|
|||
|
|
@ -229,3 +229,63 @@ def test_get_prequeries(mocker: MockFixture) -> None:
|
|||
conn.cursor().execute.assert_has_calls(
|
||||
[mocker.call("set a=1"), mocker.call("set b=2")]
|
||||
)
|
||||
|
||||
|
||||
def test_catalog_cache() -> None:
|
||||
"""
|
||||
Test the catalog cache.
|
||||
"""
|
||||
database = Database(
|
||||
database_name="db",
|
||||
sqlalchemy_uri="sqlite://",
|
||||
extra=json.dumps({"metadata_cache_timeout": {"catalog_cache_timeout": 10}}),
|
||||
)
|
||||
|
||||
assert database.catalog_cache_enabled
|
||||
assert database.catalog_cache_timeout == 10
|
||||
|
||||
|
||||
def test_get_default_catalog() -> None:
|
||||
"""
|
||||
Test the `get_default_catalog` method.
|
||||
"""
|
||||
database = Database(
|
||||
database_name="db",
|
||||
sqlalchemy_uri="postgresql://user:password@host:5432/examples",
|
||||
)
|
||||
|
||||
assert database.get_default_catalog() == "examples"
|
||||
|
||||
|
||||
def test_get_default_schema(mocker: MockFixture) -> None:
|
||||
"""
|
||||
Test the `get_default_schema` method.
|
||||
"""
|
||||
database = Database(
|
||||
database_name="db",
|
||||
sqlalchemy_uri="postgresql://user:password@host:5432/examples",
|
||||
)
|
||||
|
||||
get_inspector = mocker.patch.object(database, "get_inspector")
|
||||
with get_inspector() as inspector:
|
||||
inspector.default_schema_name = "public"
|
||||
|
||||
assert database.get_default_schema("examples") == "public"
|
||||
get_inspector.assert_called_with(catalog="examples")
|
||||
|
||||
|
||||
def test_get_all_catalog_names(mocker: MockFixture) -> None:
|
||||
"""
|
||||
Test the `get_all_catalog_names` method.
|
||||
"""
|
||||
database = Database(
|
||||
database_name="db",
|
||||
sqlalchemy_uri="postgresql://user:password@host:5432/examples",
|
||||
)
|
||||
|
||||
get_inspector = mocker.patch.object(database, "get_inspector")
|
||||
with get_inspector() as inspector:
|
||||
inspector.bind.execute.return_value = [("examples",), ("other",)]
|
||||
|
||||
assert database.get_all_catalog_names(force=True) == {"examples", "other"}
|
||||
get_inspector.assert_called_with(ssh_tunnel=None)
|
||||
|
|
|
|||
|
|
@ -26,7 +26,10 @@ from superset.connectors.sqla.models import Database, SqlaTable
|
|||
from superset.exceptions import SupersetSecurityException
|
||||
from superset.extensions import appbuilder
|
||||
from superset.models.slice import Slice
|
||||
from superset.security.manager import query_context_modified, SupersetSecurityManager
|
||||
from superset.security.manager import (
|
||||
query_context_modified,
|
||||
SupersetSecurityManager,
|
||||
)
|
||||
from superset.sql_parse import Table
|
||||
from superset.superset_typing import AdhocMetric
|
||||
from superset.utils.core import override_user
|
||||
|
|
@ -208,6 +211,7 @@ def test_raise_for_access_query_default_schema(
|
|||
SqlaTable.query_datasources_by_name.return_value = []
|
||||
|
||||
database = mocker.MagicMock()
|
||||
database.get_default_catalog.return_value = None
|
||||
database.get_default_schema_for_query.return_value = "public"
|
||||
query = mocker.MagicMock()
|
||||
query.database = database
|
||||
|
|
@ -262,6 +266,7 @@ def test_raise_for_access_jinja_sql(mocker: MockFixture, app_context: None) -> N
|
|||
SqlaTable.query_datasources_by_name.return_value = []
|
||||
|
||||
database = mocker.MagicMock()
|
||||
database.get_default_catalog.return_value = None
|
||||
database.get_default_schema_for_query.return_value = "public"
|
||||
query = mocker.MagicMock()
|
||||
query.database = database
|
||||
|
|
@ -531,3 +536,28 @@ def test_query_context_modified_mixed_chart(mocker: MockFixture) -> None:
|
|||
}
|
||||
query_context.queries = [QueryObject(metrics=requested_metrics)] # type: ignore
|
||||
assert not query_context_modified(query_context)
|
||||
|
||||
|
||||
def test_get_catalog_perm() -> None:
|
||||
"""
|
||||
Test the `get_catalog_perm` method.
|
||||
"""
|
||||
sm = SupersetSecurityManager(appbuilder)
|
||||
|
||||
assert sm.get_catalog_perm("my_db", None) is None
|
||||
assert sm.get_catalog_perm("my_db", "my_catalog") == "[my_db].[my_catalog]"
|
||||
|
||||
|
||||
def test_get_schema_perm() -> None:
|
||||
"""
|
||||
Test the `get_schema_perm` method.
|
||||
"""
|
||||
sm = SupersetSecurityManager(appbuilder)
|
||||
|
||||
assert sm.get_schema_perm("my_db", None, "my_schema") == "[my_db].[my_schema]"
|
||||
assert (
|
||||
sm.get_schema_perm("my_db", "my_catalog", "my_schema")
|
||||
== "[my_db].[my_catalog].[my_schema]"
|
||||
)
|
||||
assert sm.get_schema_perm("my_db", None, None) is None
|
||||
assert sm.get_schema_perm("my_db", "my_catalog", None) is None
|
||||
|
|
|
|||
|
|
@ -0,0 +1,54 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from pytest_mock import MockerFixture
|
||||
from sqlalchemy import create_engine
|
||||
|
||||
from superset.utils.filters import get_dataset_access_filters
|
||||
|
||||
|
||||
def test_get_dataset_access_filters(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test the `get_dataset_access_filters` function.
|
||||
"""
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.extensions import security_manager
|
||||
|
||||
mocker.patch.object(
|
||||
security_manager,
|
||||
"get_accessible_databases",
|
||||
return_value=[1, 3],
|
||||
)
|
||||
mocker.patch.object(
|
||||
security_manager,
|
||||
"user_view_menu_names",
|
||||
side_effect=[
|
||||
{"[db].[catalog1].[schema1].[table1](id:1)"},
|
||||
{"[db].[catalog1].[schema2]"},
|
||||
{"[db].[catalog2]"},
|
||||
],
|
||||
)
|
||||
|
||||
clause = get_dataset_access_filters(SqlaTable)
|
||||
engine = create_engine("sqlite://")
|
||||
compiled_query = clause.compile(engine, compile_kwargs={"literal_binds": True})
|
||||
assert str(compiled_query) == (
|
||||
"dbs.id IN (1, 3) "
|
||||
"OR tables.perm IN ('[db].[catalog1].[schema1].[table1](id:1)') "
|
||||
"OR tables.catalog_perm IN ('[db].[catalog2]') OR "
|
||||
"tables.schema_perm IN ('[db].[catalog1].[schema2]')"
|
||||
)
|
||||
|
|
@ -29,6 +29,7 @@ from superset.utils.core import (
|
|||
DateColumn,
|
||||
generic_find_constraint_name,
|
||||
generic_find_fk_constraint_name,
|
||||
get_datasource_full_name,
|
||||
is_test,
|
||||
normalize_dttm_col,
|
||||
parse_boolean_string,
|
||||
|
|
@ -369,3 +370,29 @@ def test_generic_find_fk_constraint_none_exist():
|
|||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_get_datasource_full_name():
|
||||
"""
|
||||
Test the `get_datasource_full_name` function.
|
||||
|
||||
This is used to build permissions, so it doesn't really return the datasource full
|
||||
name. Instead, it returns a fully qualified table name that includes the database
|
||||
name and schema, with each part wrapped in square brackets.
|
||||
"""
|
||||
assert (
|
||||
get_datasource_full_name("db", "table", "catalog", "schema")
|
||||
== "[db].[catalog].[schema].[table]"
|
||||
)
|
||||
|
||||
assert get_datasource_full_name("db", "table", None, None) == "[db].[table]"
|
||||
|
||||
assert (
|
||||
get_datasource_full_name("db", "table", None, "schema")
|
||||
== "[db].[schema].[table]"
|
||||
)
|
||||
|
||||
assert (
|
||||
get_datasource_full_name("db", "table", "catalog", None)
|
||||
== "[db].[catalog].[table]"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,16 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
|
@ -0,0 +1,65 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from superset.views.database.mixins import DatabaseMixin
|
||||
|
||||
|
||||
def test_pre_add_update_with_catalog(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test the `_pre_add_update` method on a DB with catalog support.
|
||||
"""
|
||||
from superset.models.core import Database
|
||||
|
||||
add_permission_view_menu = mocker.patch(
|
||||
"superset.views.database.mixins.security_manager.add_permission_view_menu"
|
||||
)
|
||||
|
||||
database = Database(
|
||||
database_name="my_db",
|
||||
id=42,
|
||||
sqlalchemy_uri="postgresql://user:password@host:5432/examples",
|
||||
)
|
||||
mocker.patch.object(
|
||||
database,
|
||||
"get_all_catalog_names",
|
||||
return_value=["examples", "other"],
|
||||
)
|
||||
mocker.patch.object(
|
||||
database,
|
||||
"get_all_schema_names",
|
||||
side_effect=[
|
||||
["public", "information_schema"],
|
||||
["secret"],
|
||||
],
|
||||
)
|
||||
|
||||
mixin = DatabaseMixin()
|
||||
mixin._pre_add_update(database)
|
||||
|
||||
add_permission_view_menu.assert_has_calls(
|
||||
[
|
||||
mocker.call("database_access", "[my_db].(id:42)"),
|
||||
mocker.call("catalog_access", "[my_db].[examples]"),
|
||||
mocker.call("catalog_access", "[my_db].[other]"),
|
||||
mocker.call("schema_access", "[my_db].[examples].[public]"),
|
||||
mocker.call("schema_access", "[my_db].[examples].[information_schema]"),
|
||||
mocker.call("schema_access", "[my_db].[other].[secret]"),
|
||||
],
|
||||
any_order=True,
|
||||
)
|
||||
Loading…
Reference in New Issue