diff --git a/superset/assets/cypress.json b/superset/assets/cypress.json index 09ee3a7d9..715717b72 100644 --- a/superset/assets/cypress.json +++ b/superset/assets/cypress.json @@ -1,7 +1,8 @@ { "baseUrl": "http://localhost:8081", "chromeWebSecurity": false, - "defaultCommandTimeout": 10000, + "defaultCommandTimeout": 20000, + "requestTimeout": 20000, "ignoreTestFiles": ["**/!(*.test.js)"], "projectId": "fbf96q", "video": false, diff --git a/superset/connectors/base/models.py b/superset/connectors/base/models.py index e9ea9d890..a95e251db 100644 --- a/superset/connectors/base/models.py +++ b/superset/connectors/base/models.py @@ -61,6 +61,7 @@ class BaseDatasource(AuditMixinNullable, ImportMixin): cache_timeout = Column(Integer) params = Column(String(1000)) perm = Column(String(1000)) + schema_perm = Column(String(1000)) sql: Optional[str] = None owners: List[User] diff --git a/superset/connectors/connector_registry.py b/superset/connectors/connector_registry.py index 9ce11802f..736b7e835 100644 --- a/superset/connectors/connector_registry.py +++ b/superset/connectors/connector_registry.py @@ -18,6 +18,7 @@ from collections import OrderedDict from typing import Dict, List, Optional, Set, Type, TYPE_CHECKING +from sqlalchemy import or_ from sqlalchemy.orm import Session, subqueryload if TYPE_CHECKING: @@ -75,13 +76,23 @@ class ConnectorRegistry(object): @classmethod def query_datasources_by_permissions( - cls, session: Session, database: "Database", permissions: Set[str] + cls, + session: Session, + database: "Database", + permissions: Set[str], + schema_perms: Set[str], ) -> List["BaseDatasource"]: + # TODO(bogdan): add unit test datasource_class = ConnectorRegistry.sources[database.type] return ( session.query(datasource_class) .filter_by(database_id=database.id) - .filter(datasource_class.perm.in_(permissions)) + .filter( + or_( + datasource_class.perm.in_(permissions), + datasource_class.schema_perm.in_(schema_perms), + ) + ) .all() ) @@ -111,5 +122,5 @@ class ConnectorRegistry(object): ) -> List["BaseDatasource"]: datasource_class = ConnectorRegistry.sources[database.type] return datasource_class.query_datasources_by_name( - session, database, datasource_name, schema=None + session, database, datasource_name, schema=schema ) diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py index 218fd1ce1..ec4001248 100644 --- a/superset/connectors/druid/models.py +++ b/superset/connectors/druid/models.py @@ -292,6 +292,10 @@ class DruidCluster(Model, AuditMixinNullable, ImportMixin): return self.verbose_name or self.cluster_name +sa.event.listen(DruidCluster, "after_insert", security_manager.set_perm) +sa.event.listen(DruidCluster, "after_update", security_manager.set_perm) + + class DruidColumn(Model, BaseColumn): """ORM model for storing Druid datasource column metadata""" @@ -529,8 +533,7 @@ class DruidDatasource(Model, BaseDatasource): else: return None - @property - def schema_perm(self) -> Optional[str]: + def get_schema_perm(self) -> Optional[str]: """Returns schema permission if present, cluster one otherwise.""" return security_manager.get_schema_perm(self.cluster, self.schema) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 93b1f4418..5dd8c8f2f 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -459,8 +459,7 @@ class SqlaTable(Model, BaseDatasource): anchor = f'{name}' return Markup(anchor) - @property - def schema_perm(self) -> Optional[str]: + def get_schema_perm(self) -> Optional[str]: """Returns schema permission if present, database one otherwise.""" return security_manager.get_schema_perm(self.database, self.schema) diff --git a/superset/migrations/versions/5afa9079866a_serialize_schema_permissions_py.py b/superset/migrations/versions/5afa9079866a_serialize_schema_permissions_py.py new file mode 100644 index 000000000..5788102b0 --- /dev/null +++ b/superset/migrations/versions/5afa9079866a_serialize_schema_permissions_py.py @@ -0,0 +1,99 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""serialize_schema_permissions.py + +Revision ID: 5afa9079866a +Revises: db4b49eb0782 +Create Date: 2019-09-11 21:49:00.608346 + +""" + + +# revision identifiers, used by Alembic. +from alembic import op +from sqlalchemy import Column, ForeignKey, Integer, String +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import relationship + +from superset import db + +revision = "5afa9079866a" +down_revision = "db4b49eb0782" + +Base = declarative_base() + + +class Sqlatable(Base): + __tablename__ = "tables" + + id = Column(Integer, primary_key=True) + perm = Column(String(1000)) + schema_perm = Column(String(1000)) + schema = Column(String(255)) + database_id = Column(Integer, ForeignKey("dbs.id"), nullable=False) + database = relationship("Database", foreign_keys=[database_id]) + + +class Slice(Base): + __tablename__ = "slices" + + id = Column(Integer, primary_key=True) + datasource_id = Column(Integer) + datasource_type = Column(String(200)) + schema_perm = Column(String(1000)) + + +class Database(Base): + __tablename__ = "dbs" + + id = Column(Integer, primary_key=True) + database_name = Column(String(250)) + verbose_name = Column(String(250), unique=True) + + +def upgrade(): + op.add_column( + "datasources", Column("schema_perm", String(length=1000), nullable=True) + ) + op.add_column("slices", Column("schema_perm", String(length=1000), nullable=True)) + op.add_column("tables", Column("schema_perm", String(length=1000), nullable=True)) + + bind = op.get_bind() + session = db.Session(bind=bind) + for t in session.query(Sqlatable).all(): + db_name = ( + t.database.verbose_name + if t.database.verbose_name + else t.database.database_name + ) + if t.schema: + t.schema_perm = f"[{db_name}].[{t.schema}]" + table_slices = ( + session.query(Slice) + .filter_by(datasource_type="table") + .filter_by(datasource_id=t.id) + .all() + ) + for s in table_slices: + s.schema_perm = t.schema_perm + session.commit() + + +def downgrade(): + op.drop_column("tables", "schema_perm") + op.drop_column("datasources", "schema_perm") + op.drop_column("slices", "schema_perm") diff --git a/superset/models/core.py b/superset/models/core.py index 26329f427..4136d7433 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -85,6 +85,7 @@ def set_related_perm(mapper, connection, target): ds = db.session.query(src_class).filter_by(id=int(id_)).first() if ds: target.perm = ds.perm + target.schema_perm = ds.schema_perm def copy_dashboard(mapper, connection, target): @@ -172,6 +173,7 @@ class Slice(Model, AuditMixinNullable, ImportMixin): description = Column(Text) cache_timeout = Column(Integer) perm = Column(String(1000)) + schema_perm = Column(String(1000)) owners = relationship(security_manager.user_model, secondary=slice_user) export_fields = [ diff --git a/superset/security.py b/superset/security.py index d8dd5edcd..8b436209e 100644 --- a/superset/security.py +++ b/superset/security.py @@ -17,12 +17,16 @@ # pylint: disable=C,R,W """A set of constants and methods to manage permissions and security""" import logging -from typing import Callable, List, Optional, Set, Tuple, TYPE_CHECKING, Union +from typing import Any, Callable, List, Optional, Set, Tuple, TYPE_CHECKING, Union from flask import current_app, g from flask_appbuilder import Model from flask_appbuilder.security.sqla import models as ab_models from flask_appbuilder.security.sqla.manager import SecurityManager +from flask_appbuilder.security.sqla.models import ( + assoc_permissionview_role, + assoc_user_role, +) from flask_appbuilder.security.views import ( PermissionModelView, PermissionViewModelView, @@ -149,6 +153,12 @@ class SupersetSecurityManager(SecurityManager): return None + def unpack_schema_perm(self, schema_permission: str) -> Tuple[str, str]: + # [database_name].[schema_name] + schema_name = schema_permission.split(".")[1][1:-1] + database_name = schema_permission.split(".")[0][1:-1] + return database_name, schema_name + def can_access(self, permission_name: str, view_name: str) -> bool: """ Return True if the user can access the FAB permission/view, False @@ -378,19 +388,48 @@ class SupersetSecurityManager(SecurityManager): if not self._datasource_access_by_fullname(database, t, schema) ] - def _user_datasource_perms(self) -> Set[str]: - """ - Return the set of FAB permission view-menu names the user can access. + def get_public_role(self) -> Optional[Any]: # Optional[self.role_model] + from superset import conf - :returns: The set of FAB permission view-menu names - """ + if not conf.get("PUBLIC_ROLE_LIKE_GAMMA", False): + return None - datasource_perms = set() - for r in g.user.roles: - for perm in r.permissions: - if perm.permission and "datasource_access" == perm.permission.name: - datasource_perms.add(perm.view_menu.name) - return datasource_perms + from superset import db + + return db.session.query(self.role_model).filter_by(name="Public").first() + + def user_view_menu_names(self, permission_name: str) -> Set[str]: + from superset import db + + base_query = ( + db.session.query(self.viewmenu_model.name) + .join(self.permissionview_model) + .join(self.permission_model) + .join(assoc_permissionview_role) + .join(self.role_model) + ) + + if not g.user.is_anonymous: + # filter by user id + view_menu_names = ( + base_query.join(assoc_user_role) + .join(self.user_model) + .filter(self.user_model.id == g.user.id) + .filter(self.permission_model.name == permission_name) + ).all() + return set([s.name for s in view_menu_names]) + + # Properly treat anonymous user + public_role = self.get_public_role() + if public_role: + # filter by public role + view_menu_names = ( + base_query.filter(self.role_model.id == public_role.id).filter( + self.permission_model.name == permission_name + ) + ).all() + return set([s.name for s in view_menu_names]) + return set() def schemas_accessible_by_user( self, database: "Database", schemas: List[str], hierarchical: bool = True @@ -412,23 +451,27 @@ class SupersetSecurityManager(SecurityManager): ): return schemas - subset = set() - for schema in schemas: - schema_perm = self.get_schema_perm(database, schema) - if schema_perm and self.can_access("schema_access", schema_perm): - subset.add(schema) + # schema_access + accessible_schemas = { + self.unpack_schema_perm(s)[1] + for s in self.user_view_menu_names("schema_access") + if s.startswith(f"[{database}].") + } - perms = self._user_datasource_perms() + # datasource_access + perms = self.user_view_menu_names("datasource_access") if perms: tables = ( - db.session.query(SqlaTable) - .filter(SqlaTable.perm.in_(perms), SqlaTable.database_id == database.id) - .all() + db.session.query(SqlaTable.schema) + .filter(SqlaTable.database_id == database.id) + .filter(SqlaTable.schema.isnot(None)) + .filter(SqlaTable.schema != "") + .filter(or_(SqlaTable.perm.in_(perms))) + .distinct() ) - for t in tables: - if t.schema: - subset.add(t.schema) - return sorted(list(subset)) + accessible_schemas.update([t.schema for t in tables]) + + return [s for s in schemas if s in accessible_schemas] def get_datasources_accessible_by_user( self, @@ -455,9 +498,10 @@ class SupersetSecurityManager(SecurityManager): if schema_perm and self.can_access("schema_access", schema_perm): return datasource_names - user_perms = self._user_datasource_perms() + user_perms = self.user_view_menu_names("datasource_access") + schema_perms = self.user_view_menu_names("schema_access") user_datasources = ConnectorRegistry.query_datasources_by_permissions( - db.session, database, user_perms + db.session, database, user_perms, schema_perms ) if schema: names = {d.table_name for d in user_datasources if d.schema == schema} @@ -525,7 +569,7 @@ class SupersetSecurityManager(SecurityManager): datasources = ConnectorRegistry.get_all_datasources(db.session) for datasource in datasources: merge_pv("datasource_access", datasource.get_perm()) - merge_pv("schema_access", datasource.schema_perm) + merge_pv("schema_access", datasource.get_schema_perm()) logging.info("Creating missing database permissions.") databases = db.session.query(models.Database).all() @@ -737,49 +781,69 @@ class SupersetSecurityManager(SecurityManager): :param connection: The DB-API connection :param target: The mapped instance being persisted """ - + link_table = target.__table__ # pylint: disable=no-member if target.perm != target.get_perm(): - link_table = target.__table__ connection.execute( link_table.update() .where(link_table.c.id == target.id) .values(perm=target.get_perm()) ) - # add to view menu if not already exists - permission_name = "datasource_access" - view_menu_name = target.get_perm() - permission = self.find_permission(permission_name) - view_menu = self.find_view_menu(view_menu_name) - pv = None - - if not permission: - permission_table = ( - self.permission_model.__table__ # pylint: disable=no-member - ) - connection.execute(permission_table.insert().values(name=permission_name)) - permission = self.find_permission(permission_name) - if not view_menu: - view_menu_table = self.viewmenu_model.__table__ # pylint: disable=no-member - connection.execute(view_menu_table.insert().values(name=view_menu_name)) - view_menu = self.find_view_menu(view_menu_name) - - if permission and view_menu: - pv = ( - self.get_session.query(self.permissionview_model) - .filter_by(permission=permission, view_menu=view_menu) - .first() - ) - if not pv and permission and view_menu: - permission_view_table = ( - self.permissionview_model.__table__ # pylint: disable=no-member - ) + if ( + hasattr(target, "schema_perm") + and target.schema_perm != target.get_schema_perm() + ): connection.execute( - permission_view_table.insert().values( - permission_id=permission.id, view_menu_id=view_menu.id - ) + link_table.update() + .where(link_table.c.id == target.id) + .values(schema_perm=target.get_schema_perm()) ) + pvm_names = [] + if target.__tablename__ in {"dbs", "clusters"}: + pvm_names.append(("database_access", target.get_perm())) + else: + pvm_names.append(("datasource_access", target.get_perm())) + if target.schema: + pvm_names.append(("schema_access", target.get_schema_perm())) + + # TODO(bogdan): modify slice permissions as well. + for permission_name, view_menu_name in pvm_names: + permission = self.find_permission(permission_name) + view_menu = self.find_view_menu(view_menu_name) + pv = None + + if not permission: + permission_table = ( + self.permission_model.__table__ # pylint: disable=no-member + ) + connection.execute( + permission_table.insert().values(name=permission_name) + ) + permission = self.find_permission(permission_name) + if not view_menu: + view_menu_table = ( + self.viewmenu_model.__table__ # pylint: disable=no-member + ) + connection.execute(view_menu_table.insert().values(name=view_menu_name)) + view_menu = self.find_view_menu(view_menu_name) + + if permission and view_menu: + pv = ( + self.get_session.query(self.permissionview_model) + .filter_by(permission=permission, view_menu=view_menu) + .first() + ) + if not pv and permission and view_menu: + permission_view_table = ( + self.permissionview_model.__table__ # pylint: disable=no-member + ) + connection.execute( + permission_view_table.insert().values( + permission_id=permission.id, view_menu_id=view_menu.id + ) + ) + def assert_datasource_permission(self, datasource: "BaseDatasource") -> None: """ Assert the the user has permission to access the Superset datasource. diff --git a/superset/views/base.py b/superset/views/base.py index 739de1c65..ec8ff69b2 100644 --- a/superset/views/base.py +++ b/superset/views/base.py @@ -31,6 +31,7 @@ from flask_appbuilder.models.sqla.filters import BaseFilter from flask_appbuilder.widgets import ListWidget from flask_babel import get_locale, gettext as __, lazy_gettext as _ from flask_wtf.form import FlaskForm +from sqlalchemy import or_ from werkzeug.exceptions import HTTPException from wtforms.fields.core import Field, UnboundField @@ -348,53 +349,18 @@ class DeleteMixin(object): return redirect(self.get_redirect()) -class SupersetFilter(BaseFilter): - - """Add utility function to make BaseFilter easy and fast - - These utility function exist in the SecurityManager, but would do - a database round trip at every check. Here we cache the role objects - to be able to make multiple checks but query the db only once - """ - - def get_user_roles(self): - return get_user_roles() - - def get_all_permissions(self): - """Returns a set of tuples with the perm name and view menu name""" - perms = set() - for role in self.get_user_roles(): - for perm_view in role.permissions: - t = (perm_view.permission.name, perm_view.view_menu.name) - perms.add(t) - return perms - - def has_role(self, role_name_or_list): - """Whether the user has this role name""" - if not isinstance(role_name_or_list, list): - role_name_or_list = [role_name_or_list] - return any([r.name in role_name_or_list for r in self.get_user_roles()]) - - def has_perm(self, permission_name, view_menu_name): - """Whether the user has this perm""" - return (permission_name, view_menu_name) in self.get_all_permissions() - - def get_view_menus(self, permission_name): - """Returns the details of view_menus for a perm name""" - vm = set() - for perm_name, vm_name in self.get_all_permissions(): - if perm_name == permission_name: - vm.add(vm_name) - return vm - - -class DatasourceFilter(SupersetFilter): - def apply(self, query, func): +class DatasourceFilter(BaseFilter): + def apply(self, query, func): # noqa if security_manager.all_datasource_access(): return query - perms = self.get_view_menus("datasource_access") - # TODO(bogdan): add `schema_access` support here - return query.filter(self.model.perm.in_(perms)) + datasource_perms = security_manager.user_view_menu_names("datasource_access") + schema_perms = security_manager.user_view_menu_names("schema_access") + return query.filter( + or_( + self.model.perm.in_(datasource_perms), + self.model.schema_perm.in_(schema_perms), + ) + ) class CsvResponse(Response): diff --git a/superset/views/core.py b/superset/views/core.py index fd7a6f68f..958f6421d 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -88,6 +88,7 @@ from superset.utils.decorators import etag_cache, stats_timing from .base import ( api, + BaseFilter, BaseSupersetView, check_ownership, CsvResponse, @@ -99,7 +100,6 @@ from .base import ( handle_api_exception, json_error_response, json_success, - SupersetFilter, SupersetModelView, ) from .database import api as database_api, views as in_views @@ -243,16 +243,18 @@ def _deserialize_results_payload( return json.loads(payload) # type: ignore -class SliceFilter(SupersetFilter): - def apply(self, query, func): +class SliceFilter(BaseFilter): + def apply(self, query, func): # noqa if security_manager.all_datasource_access(): return query - perms = self.get_view_menus("datasource_access") - # TODO(bogdan): add `schema_access` support here - return query.filter(self.model.perm.in_(perms)) + perms = security_manager.user_view_menu_names("datasource_access") + schema_perms = security_manager.user_view_menu_names("schema_access") + return query.filter( + or_(self.model.perm.in_(perms), self.model.schema_perm.in_(schema_perms)) + ) -class DashboardFilter(SupersetFilter): +class DashboardFilter(BaseFilter): """ List dashboards with the following criteria: 1. Those which the user owns @@ -270,19 +272,24 @@ class DashboardFilter(SupersetFilter): Slice = models.Slice Favorites = models.FavStar - user_roles = [role.name.lower() for role in list(self.get_user_roles())] + user_roles = [role.name.lower() for role in list(get_user_roles())] if "admin" in user_roles: return query - datasource_perms = self.get_view_menus("datasource_access") + datasource_perms = security_manager.user_view_menu_names("datasource_access") + schema_perms = security_manager.user_view_menu_names("schema_access") all_datasource_access = security_manager.all_datasource_access() published_dash_query = ( db.session.query(Dash.id) .join(Dash.slices) .filter( and_( - Dash.published == True, - or_(Slice.perm.in_(datasource_perms), all_datasource_access), + Dash.published == True, # noqa + or_( + Slice.perm.in_(datasource_perms), + Slice.schema_perm.in_(schema_perms), + all_datasource_access, + ), ) ) ) diff --git a/superset/views/database/mixins.py b/superset/views/database/mixins.py index ff8dc1e9e..f9cebe9f9 100644 --- a/superset/views/database/mixins.py +++ b/superset/views/database/mixins.py @@ -18,20 +18,37 @@ import inspect from flask import Markup from flask_babel import lazy_gettext as _ -from sqlalchemy import MetaData +from sqlalchemy import MetaData, or_ from superset import security_manager from superset.exceptions import SupersetException from superset.utils import core as utils -from superset.views.base import SupersetFilter +from superset.views.base import BaseFilter -class DatabaseFilter(SupersetFilter): - def apply(self, query, value): +class DatabaseFilter(BaseFilter): + # TODO(bogdan): consider caching. + def schema_access_databases(self): # noqa pylint: disable=no-self-use + found_databases = set() + for vm in security_manager.user_view_menu_names("schema_access"): + database_name, _ = security_manager.unpack_schema_perm(vm) + found_databases.add(database_name) + return found_databases + + def apply( + self, query, func + ): # noqa pylint: disable=unused-argument,arguments-differ if security_manager.all_database_access(): return query - perms = self.get_view_menus("database_access") - return query.filter(self.model.perm.in_(perms)) + database_perms = security_manager.user_view_menu_names("database_access") + # TODO(bogdan): consider adding datasource access here as well. + schema_access_databases = self.schema_access_databases() + return query.filter( + or_( + self.model.perm.in_(database_perms), + self.model.database_name.in_(schema_access_databases), + ) + ) class DatabaseMixin: diff --git a/superset/views/sql_lab.py b/superset/views/sql_lab.py index fe072956f..29d609443 100644 --- a/superset/views/sql_lab.py +++ b/superset/views/sql_lab.py @@ -30,15 +30,15 @@ from superset.models.sql_lab import Query, SavedQuery, TableSchema, TabState from superset.utils import core as utils from .base import ( + BaseFilter, BaseSupersetView, DeleteMixin, json_success, - SupersetFilter, SupersetModelView, ) -class QueryFilter(SupersetFilter): +class QueryFilter(BaseFilter): def apply(self, query: BaseQuery, func: Callable) -> BaseQuery: """ Filter queries to only those owned by current user if diff --git a/tests/security_tests.py b/tests/security_tests.py index 47c2b1405..f5fbb1c80 100644 --- a/tests/security_tests.py +++ b/tests/security_tests.py @@ -18,8 +18,14 @@ import inspect import unittest from unittest.mock import Mock, patch -from superset import app, appbuilder, security_manager, viz +import prison + +from superset import app, appbuilder, db, security_manager, viz +from superset.connectors.druid.models import DruidCluster, DruidDatasource +from superset.connectors.sqla.models import SqlaTable from superset.exceptions import SupersetSecurityException +from superset.models.core import Database, Slice +from superset.utils.core import get_example_database from .base_tests import SupersetTestCase @@ -31,8 +37,471 @@ def get_perm_tuples(role_name): return perm_set +SCHEMA_ACCESS_ROLE = "schema_access_role" + + +def create_schema_perm(view_menu_name: str) -> None: + permission = "schema_access" + security_manager.add_permission_view_menu(permission, view_menu_name) + perm_view = security_manager.find_permission_view_menu(permission, view_menu_name) + security_manager.add_permission_role( + security_manager.find_role(SCHEMA_ACCESS_ROLE), perm_view + ) + return None + + +def delete_schema_perm(view_menu_name: str) -> None: + pv = security_manager.find_permission_view_menu("schema_access", "[examples].[2]") + security_manager.del_permission_role( + security_manager.find_role(SCHEMA_ACCESS_ROLE), pv + ) + security_manager.del_permission_view_menu("schema_access", "[examples].[2]") + return None + + class RolePermissionTests(SupersetTestCase): - """Testing export import functionality for dashboards""" + """Testing export role permissions.""" + + def setUp(self): + session = db.session + security_manager.add_role(SCHEMA_ACCESS_ROLE) + session.commit() + + ds = ( + db.session.query(SqlaTable) + .filter_by(table_name="wb_health_population") + .first() + ) + ds.schema = "temp_schema" + ds.schema_perm = ds.get_schema_perm() + + ds_slices = ( + session.query(Slice) + .filter_by(datasource_type="table") + .filter_by(datasource_id=ds.id) + .all() + ) + for s in ds_slices: + s.schema_perm = ds.schema_perm + create_schema_perm("[examples].[temp_schema]") + gamma_user = security_manager.find_user(username="gamma") + gamma_user.roles.append(security_manager.find_role(SCHEMA_ACCESS_ROLE)) + session.commit() + + def tearDown(self): + session = db.session + ds = ( + session.query(SqlaTable) + .filter_by(table_name="wb_health_population") + .first() + ) + schema_perm = ds.schema_perm + ds.schema = None + ds.schema_perm = None + ds_slices = ( + session.query(Slice) + .filter_by(datasource_type="table") + .filter_by(datasource_id=ds.id) + .all() + ) + for s in ds_slices: + s.schema_perm = None + + delete_schema_perm(schema_perm) + session.delete(security_manager.find_role(SCHEMA_ACCESS_ROLE)) + session.commit() + + def test_set_perm_sqla_table(self): + session = db.session + table = SqlaTable( + schema="tmp_schema", + table_name="tmp_perm_table", + database=get_example_database(), + ) + session.add(table) + session.commit() + + stored_table = ( + session.query(SqlaTable).filter_by(table_name="tmp_perm_table").one() + ) + self.assertEquals( + stored_table.perm, f"[examples].[tmp_perm_table](id:{stored_table.id})" + ) + self.assertIsNotNone( + security_manager.find_permission_view_menu( + "datasource_access", stored_table.perm + ) + ) + self.assertEquals(stored_table.schema_perm, "[examples].[tmp_schema]") + self.assertIsNotNone( + security_manager.find_permission_view_menu( + "schema_access", stored_table.schema_perm + ) + ) + + # table name change + stored_table.table_name = "tmp_perm_table_v2" + session.commit() + stored_table = ( + session.query(SqlaTable).filter_by(table_name="tmp_perm_table_v2").one() + ) + self.assertEquals( + stored_table.perm, f"[examples].[tmp_perm_table_v2](id:{stored_table.id})" + ) + self.assertIsNotNone( + security_manager.find_permission_view_menu( + "datasource_access", stored_table.perm + ) + ) + # no changes in schema + self.assertEquals(stored_table.schema_perm, "[examples].[tmp_schema]") + self.assertIsNotNone( + security_manager.find_permission_view_menu( + "schema_access", stored_table.schema_perm + ) + ) + + # schema name change + stored_table.schema = "tmp_schema_v2" + session.commit() + stored_table = ( + session.query(SqlaTable).filter_by(table_name="tmp_perm_table_v2").one() + ) + self.assertEquals( + stored_table.perm, f"[examples].[tmp_perm_table_v2](id:{stored_table.id})" + ) + self.assertIsNotNone( + security_manager.find_permission_view_menu( + "datasource_access", stored_table.perm + ) + ) + # no changes in schema + self.assertEquals(stored_table.schema_perm, "[examples].[tmp_schema_v2]") + self.assertIsNotNone( + security_manager.find_permission_view_menu( + "schema_access", stored_table.schema_perm + ) + ) + + # database change + new_db = Database(sqlalchemy_uri="some_uri", database_name="tmp_db") + session.add(new_db) + stored_table.database = ( + session.query(Database).filter_by(database_name="tmp_db").one() + ) + session.commit() + stored_table = ( + session.query(SqlaTable).filter_by(table_name="tmp_perm_table_v2").one() + ) + self.assertEquals( + stored_table.perm, f"[tmp_db].[tmp_perm_table_v2](id:{stored_table.id})" + ) + self.assertIsNotNone( + security_manager.find_permission_view_menu( + "datasource_access", stored_table.perm + ) + ) + # no changes in schema + self.assertEquals(stored_table.schema_perm, "[tmp_db].[tmp_schema_v2]") + self.assertIsNotNone( + security_manager.find_permission_view_menu( + "schema_access", stored_table.schema_perm + ) + ) + + # no schema + stored_table.schema = None + session.commit() + stored_table = ( + session.query(SqlaTable).filter_by(table_name="tmp_perm_table_v2").one() + ) + self.assertEquals( + stored_table.perm, f"[tmp_db].[tmp_perm_table_v2](id:{stored_table.id})" + ) + self.assertIsNotNone( + security_manager.find_permission_view_menu( + "datasource_access", stored_table.perm + ) + ) + self.assertIsNone(stored_table.schema_perm) + + session.delete(new_db) + session.delete(stored_table) + session.commit() + + def test_set_perm_druid_datasource(self): + session = db.session + druid_cluster = ( + session.query(DruidCluster).filter_by(cluster_name="druid_test").one() + ) + datasource = DruidDatasource( + datasource_name="tmp_datasource", + cluster=druid_cluster, + cluster_name="druid_test", + ) + session.add(datasource) + session.commit() + + # store without a schema + stored_datasource = ( + session.query(DruidDatasource) + .filter_by(datasource_name="tmp_datasource") + .one() + ) + self.assertEquals( + stored_datasource.perm, + f"[druid_test].[tmp_datasource](id:{stored_datasource.id})", + ) + self.assertIsNotNone( + security_manager.find_permission_view_menu( + "datasource_access", stored_datasource.perm + ) + ) + self.assertIsNone(stored_datasource.schema_perm) + + # store with a schema + stored_datasource.datasource_name = "tmp_schema.tmp_datasource" + session.commit() + self.assertEquals( + stored_datasource.perm, + f"[druid_test].[tmp_schema.tmp_datasource](id:{stored_datasource.id})", + ) + self.assertIsNotNone( + security_manager.find_permission_view_menu( + "datasource_access", stored_datasource.perm + ) + ) + self.assertIsNotNone(stored_datasource.schema_perm, "[druid_test].[tmp_schema]") + self.assertIsNotNone( + security_manager.find_permission_view_menu( + "schema_access", stored_datasource.schema_perm + ) + ) + + session.delete(stored_datasource) + session.commit() + + def test_set_perm_druid_cluster(self): + session = db.session + cluster = DruidCluster(cluster_name="tmp_druid_cluster") + session.add(cluster) + + stored_cluster = ( + session.query(DruidCluster) + .filter_by(cluster_name="tmp_druid_cluster") + .one() + ) + self.assertEquals( + stored_cluster.perm, f"[tmp_druid_cluster].(id:{stored_cluster.id})" + ) + self.assertIsNotNone( + security_manager.find_permission_view_menu( + "database_access", stored_cluster.perm + ) + ) + + stored_cluster.cluster_name = "tmp_druid_cluster2" + session.commit() + self.assertEquals( + stored_cluster.perm, f"[tmp_druid_cluster2].(id:{stored_cluster.id})" + ) + self.assertIsNotNone( + security_manager.find_permission_view_menu( + "database_access", stored_cluster.perm + ) + ) + + session.delete(stored_cluster) + session.commit() + + def test_set_perm_database(self): + session = db.session + database = Database(database_name="tmp_database") + session.add(database) + + stored_db = ( + session.query(Database).filter_by(database_name="tmp_database").one() + ) + self.assertEquals(stored_db.perm, f"[tmp_database].(id:{stored_db.id})") + self.assertIsNotNone( + security_manager.find_permission_view_menu( + "database_access", stored_db.perm + ) + ) + + stored_db.database_name = "tmp_database2" + session.commit() + stored_db = ( + session.query(Database).filter_by(database_name="tmp_database2").one() + ) + self.assertEquals(stored_db.perm, f"[tmp_database2].(id:{stored_db.id})") + self.assertIsNotNone( + security_manager.find_permission_view_menu( + "database_access", stored_db.perm + ) + ) + + session.delete(stored_db) + session.commit() + + def test_set_perm_slice(self): + session = db.session + database = Database(database_name="tmp_database") + table = SqlaTable(table_name="tmp_perm_table", database=database) + session.add(database) + session.add(table) + session.commit() + + # no schema permission + slice = Slice( + datasource_id=table.id, + datasource_type="table", + datasource_name="tmp_perm_table", + slice_name="slice_name", + ) + session.add(slice) + session.commit() + + slice = session.query(Slice).filter_by(slice_name="slice_name").one() + self.assertEquals(slice.perm, table.perm) + self.assertEquals(slice.perm, f"[tmp_database].[tmp_perm_table](id:{table.id})") + self.assertEquals(slice.schema_perm, table.schema_perm) + self.assertIsNone(slice.schema_perm) + + table.schema = "tmp_perm_schema" + table.table_name = "tmp_perm_table_v2" + session.commit() + # TODO(bogdan): modify slice permissions on the table update. + self.assertNotEquals(slice.perm, table.perm) + self.assertEquals(slice.perm, f"[tmp_database].[tmp_perm_table](id:{table.id})") + self.assertEquals( + table.perm, f"[tmp_database].[tmp_perm_table_v2](id:{table.id})" + ) + # TODO(bogdan): modify slice schema permissions on the table update. + self.assertNotEquals(slice.schema_perm, table.schema_perm) + self.assertIsNone(slice.schema_perm) + + # updating slice refreshes the permissions + slice.slice_name = "slice_name_v2" + session.commit() + self.assertEquals(slice.perm, table.perm) + self.assertEquals( + slice.perm, f"[tmp_database].[tmp_perm_table_v2](id:{table.id})" + ) + self.assertEquals(slice.schema_perm, table.schema_perm) + self.assertEquals(slice.schema_perm, "[tmp_database].[tmp_perm_schema]") + + session.delete(slice) + session.delete(table) + session.delete(database) + + session.commit() + + # TODO test slice permission + + @patch("superset.security.g") + def test_schemas_accessible_by_user_admin(self, mock_g): + mock_g.user = security_manager.find_user("admin") + with self.client.application.test_request_context(): + database = get_example_database() + schemas = security_manager.schemas_accessible_by_user( + database, ["1", "2", "3"] + ) + self.assertEquals(schemas, ["1", "2", "3"]) # no changes + + @patch("superset.security.g") + def test_schemas_accessible_by_user_schema_access(self, mock_g): + # User has schema access to the schema 1 + create_schema_perm("[examples].[1]") + mock_g.user = security_manager.find_user("gamma") + with self.client.application.test_request_context(): + database = get_example_database() + schemas = security_manager.schemas_accessible_by_user( + database, ["1", "2", "3"] + ) + # temp_schema is not passed in the params + self.assertEquals(schemas, ["1"]) + delete_schema_perm("[examples].[1]") + + @patch("superset.security.g") + def test_schemas_accessible_by_user_datasource_access(self, mock_g): + # User has schema access to the datasource temp_schema.wb_health_population in examples DB. + mock_g.user = security_manager.find_user("gamma") + with self.client.application.test_request_context(): + database = get_example_database() + schemas = security_manager.schemas_accessible_by_user( + database, ["temp_schema", "2", "3"] + ) + self.assertEquals(schemas, ["temp_schema"]) + + @patch("superset.security.g") + def test_schemas_accessible_by_user_datasource_and_schema_access(self, mock_g): + # User has schema access to the datasource temp_schema.wb_health_population in examples DB. + create_schema_perm("[examples].[2]") + mock_g.user = security_manager.find_user("gamma") + with self.client.application.test_request_context(): + database = get_example_database() + schemas = security_manager.schemas_accessible_by_user( + database, ["temp_schema", "2", "3"] + ) + self.assertEquals(schemas, ["temp_schema", "2"]) + vm = security_manager.find_permission_view_menu( + "schema_access", "[examples].[2]" + ) + self.assertIsNotNone(vm) + delete_schema_perm("[examples].[2]") + + def test_gamma_user_schema_access_to_dashboards(self): + self.login(username="gamma") + data = str(self.client.get("dashboard/list/").data) + self.assertIn("/superset/dashboard/world_health/", data) + self.assertNotIn("/superset/dashboard/births/", data) + + def test_gamma_user_schema_access_to_tables(self): + self.login(username="gamma") + data = str(self.client.get("tablemodelview/list/").data) + self.assertIn("wb_health_population", data) + self.assertNotIn("birth_names", data) + + def test_gamma_user_schema_access_to_charts(self): + self.login(username="gamma") + data = str(self.client.get("chart/list/").data) + self.assertIn( + "Life Expectancy VS Rural %", data + ) # wb_health_population slice, has access + self.assertIn( + "Parallel Coordinates", data + ) # wb_health_population slice, has access + self.assertNotIn("Girl Name Cloud", data) # birth_names slice, no access + + def test_sqllab_gamma_user_schema_access_to_sqllab(self): + session = db.session + + example_db = session.query(Database).filter_by(database_name="examples").one() + example_db.expose_in_sqllab = True + session.commit() + + OLD_FLASK_GET_SQL_DBS_REQUEST = ( + "databaseasync/api/read?_flt_0_expose_in_sqllab=1&" + "_oc_DatabaseAsync=database_name&_od_DatabaseAsync=asc" + ) + self.login(username="gamma") + databases_json = self.client.get(OLD_FLASK_GET_SQL_DBS_REQUEST).json + self.assertEquals(databases_json["count"], 1) + + arguments = { + "keys": ["none"], + "filters": [{"col": "expose_in_sqllab", "opr": "eq", "value": True}], + "order_columns": "database_name", + "order_direction": "asc", + "page": 0, + "page_size": -1, + } + NEW_FLASK_GET_SQL_DBS_REQUEST = f"/api/v1/database/?q={prison.dumps(arguments)}" + self.login(username="gamma") + databases_json = self.client.get(NEW_FLASK_GET_SQL_DBS_REQUEST).json + self.assertEquals(databases_json["count"], 1) + self.logout() def assert_can_read(self, view_menu, permissions_set): self.assertIn(("can_show", view_menu), permissions_set) diff --git a/tests/utils_tests.py b/tests/utils_tests.py index 94603dbb8..8b0c45af5 100644 --- a/tests/utils_tests.py +++ b/tests/utils_tests.py @@ -888,9 +888,7 @@ class UtilsTestCase(SupersetTestCase): self.assertIsNotNone(database) self.assertEqual(database.sqlalchemy_uri, "sqlite:///superset.db") self.assertIsNotNone( - security_manager.find_permission_view_menu( - "datasource_access", database.perm - ) + security_manager.find_permission_view_menu("database_access", database.perm) ) # Test change URI get_or_create_db("test_db", "sqlite:///changed.db")