diff --git a/.pylintrc b/.pylintrc index 881495719..dbc0eabae 100644 --- a/.pylintrc +++ b/.pylintrc @@ -303,7 +303,7 @@ ignored-modules=numpy,pandas,alembic.op,sqlalchemy,alembic.context,flask_appbuil # List of class names for which member attributes should not be checked (useful # for classes with dynamically set attributes). This supports the use of # qualified names. -ignored-classes=contextlib.closing,optparse.Values,thread._local,_thread._local,sqlalchemy.orm.scoping.scoped_session +ignored-classes=contextlib.closing,optparse.Values,thread._local,_thread._local # List of members which are set dynamically and missed by pylint inference # system, and so shouldn't trigger E1101 when accessed. Python regular diff --git a/superset/annotation_layers/annotations/api.py b/superset/annotation_layers/annotations/api.py index 291c074fa..c0af2dce6 100644 --- a/superset/annotation_layers/annotations/api.py +++ b/superset/annotation_layers/annotations/api.py @@ -17,7 +17,7 @@ import logging from typing import Any, Dict -from flask import g, request, Response +from flask import request, Response from flask_appbuilder.api import expose, permission_name, protect, rison, safe from flask_appbuilder.api.schemas import get_item_schema, get_list_schema from flask_appbuilder.models.sqla.interface import SQLAInterface @@ -306,7 +306,7 @@ class AnnotationRestApi(BaseSupersetModelRestApi): except ValidationError as error: return self.response_400(message=error.messages) try: - new_model = CreateAnnotationCommand(g.user, item).run() + new_model = CreateAnnotationCommand(item).run() return self.response(201, id=new_model.id, result=item) except AnnotationLayerNotFoundError as ex: return self.response_400(message=str(ex)) @@ -381,7 +381,7 @@ class AnnotationRestApi(BaseSupersetModelRestApi): except ValidationError as error: return self.response_400(message=error.messages) try: - new_model = UpdateAnnotationCommand(g.user, annotation_id, item).run() + new_model = UpdateAnnotationCommand(annotation_id, item).run() return self.response(200, id=new_model.id, result=item) except (AnnotationNotFoundError, AnnotationLayerNotFoundError): return self.response_404() @@ -438,7 +438,7 @@ class AnnotationRestApi(BaseSupersetModelRestApi): $ref: '#/components/responses/500' """ try: - DeleteAnnotationCommand(g.user, annotation_id).run() + DeleteAnnotationCommand(annotation_id).run() return self.response(200, message="OK") except AnnotationNotFoundError: return self.response_404() @@ -495,7 +495,7 @@ class AnnotationRestApi(BaseSupersetModelRestApi): """ item_ids = kwargs["rison"] try: - BulkDeleteAnnotationCommand(g.user, item_ids).run() + BulkDeleteAnnotationCommand(item_ids).run() return self.response( 200, message=ngettext( diff --git a/superset/annotation_layers/annotations/commands/bulk_delete.py b/superset/annotation_layers/annotations/commands/bulk_delete.py index 6a164c877..113725050 100644 --- a/superset/annotation_layers/annotations/commands/bulk_delete.py +++ b/superset/annotation_layers/annotations/commands/bulk_delete.py @@ -17,8 +17,6 @@ import logging from typing import List, Optional -from flask_appbuilder.security.sqla.models import User - from superset.annotation_layers.annotations.commands.exceptions import ( AnnotationBulkDeleteFailedError, AnnotationNotFoundError, @@ -32,8 +30,7 @@ logger = logging.getLogger(__name__) class BulkDeleteAnnotationCommand(BaseCommand): - def __init__(self, user: User, model_ids: List[int]): - self._actor = user + def __init__(self, model_ids: List[int]): self._model_ids = model_ids self._models: Optional[List[Annotation]] = None diff --git a/superset/annotation_layers/annotations/commands/create.py b/superset/annotation_layers/annotations/commands/create.py index d745df121..dcfa6c852 100644 --- a/superset/annotation_layers/annotations/commands/create.py +++ b/superset/annotation_layers/annotations/commands/create.py @@ -19,7 +19,6 @@ from datetime import datetime from typing import Any, Dict, List, Optional from flask_appbuilder.models.sqla import Model -from flask_appbuilder.security.sqla.models import User from marshmallow import ValidationError from superset.annotation_layers.annotations.commands.exceptions import ( @@ -38,8 +37,7 @@ logger = logging.getLogger(__name__) class CreateAnnotationCommand(BaseCommand): - def __init__(self, user: User, data: Dict[str, Any]): - self._actor = user + def __init__(self, data: Dict[str, Any]): self._properties = data.copy() def run(self) -> Model: diff --git a/superset/annotation_layers/annotations/commands/delete.py b/superset/annotation_layers/annotations/commands/delete.py index 3d874818d..915f7f80c 100644 --- a/superset/annotation_layers/annotations/commands/delete.py +++ b/superset/annotation_layers/annotations/commands/delete.py @@ -18,7 +18,6 @@ import logging from typing import Optional from flask_appbuilder.models.sqla import Model -from flask_appbuilder.security.sqla.models import User from superset.annotation_layers.annotations.commands.exceptions import ( AnnotationDeleteFailedError, @@ -33,8 +32,7 @@ logger = logging.getLogger(__name__) class DeleteAnnotationCommand(BaseCommand): - def __init__(self, user: User, model_id: int): - self._actor = user + def __init__(self, model_id: int): self._model_id = model_id self._model: Optional[Annotation] = None diff --git a/superset/annotation_layers/annotations/commands/update.py b/superset/annotation_layers/annotations/commands/update.py index 9e3012acb..c55a1cdaf 100644 --- a/superset/annotation_layers/annotations/commands/update.py +++ b/superset/annotation_layers/annotations/commands/update.py @@ -19,7 +19,6 @@ from datetime import datetime from typing import Any, Dict, List, Optional from flask_appbuilder.models.sqla import Model -from flask_appbuilder.security.sqla.models import User from marshmallow import ValidationError from superset.annotation_layers.annotations.commands.exceptions import ( @@ -40,8 +39,7 @@ logger = logging.getLogger(__name__) class UpdateAnnotationCommand(BaseCommand): - def __init__(self, user: User, model_id: int, data: Dict[str, Any]): - self._actor = user + def __init__(self, model_id: int, data: Dict[str, Any]): self._model_id = model_id self._properties = data.copy() self._model: Optional[Annotation] = None diff --git a/superset/annotation_layers/api.py b/superset/annotation_layers/api.py index db3979f66..8ef343cae 100644 --- a/superset/annotation_layers/api.py +++ b/superset/annotation_layers/api.py @@ -17,7 +17,7 @@ import logging from typing import Any -from flask import g, request, Response +from flask import request, Response from flask_appbuilder.api import expose, permission_name, protect, rison, safe from flask_appbuilder.models.sqla.interface import SQLAInterface from flask_babel import ngettext @@ -151,7 +151,7 @@ class AnnotationLayerRestApi(BaseSupersetModelRestApi): $ref: '#/components/responses/500' """ try: - DeleteAnnotationLayerCommand(g.user, pk).run() + DeleteAnnotationLayerCommand(pk).run() return self.response(200, message="OK") except AnnotationLayerNotFoundError: return self.response_404() @@ -216,7 +216,7 @@ class AnnotationLayerRestApi(BaseSupersetModelRestApi): except ValidationError as error: return self.response_400(message=error.messages) try: - new_model = CreateAnnotationLayerCommand(g.user, item).run() + new_model = CreateAnnotationLayerCommand(item).run() return self.response(201, id=new_model.id, result=item) except AnnotationLayerNotFoundError as ex: return self.response_400(message=str(ex)) @@ -288,7 +288,7 @@ class AnnotationLayerRestApi(BaseSupersetModelRestApi): except ValidationError as error: return self.response_400(message=error.messages) try: - new_model = UpdateAnnotationLayerCommand(g.user, pk, item).run() + new_model = UpdateAnnotationLayerCommand(pk, item).run() return self.response(200, id=new_model.id, result=item) except AnnotationLayerNotFoundError: return self.response_404() @@ -346,7 +346,7 @@ class AnnotationLayerRestApi(BaseSupersetModelRestApi): """ item_ids = kwargs["rison"] try: - BulkDeleteAnnotationLayerCommand(g.user, item_ids).run() + BulkDeleteAnnotationLayerCommand(item_ids).run() return self.response( 200, message=ngettext( diff --git a/superset/annotation_layers/commands/bulk_delete.py b/superset/annotation_layers/commands/bulk_delete.py index a828047fd..b9bc17e82 100644 --- a/superset/annotation_layers/commands/bulk_delete.py +++ b/superset/annotation_layers/commands/bulk_delete.py @@ -17,8 +17,6 @@ import logging from typing import List, Optional -from flask_appbuilder.security.sqla.models import User - from superset.annotation_layers.commands.exceptions import ( AnnotationLayerBulkDeleteFailedError, AnnotationLayerBulkDeleteIntegrityError, @@ -33,8 +31,7 @@ logger = logging.getLogger(__name__) class BulkDeleteAnnotationLayerCommand(BaseCommand): - def __init__(self, user: User, model_ids: List[int]): - self._actor = user + def __init__(self, model_ids: List[int]): self._model_ids = model_ids self._models: Optional[List[AnnotationLayer]] = None diff --git a/superset/annotation_layers/commands/create.py b/superset/annotation_layers/commands/create.py index ee42ce755..d5af6c24a 100644 --- a/superset/annotation_layers/commands/create.py +++ b/superset/annotation_layers/commands/create.py @@ -18,7 +18,6 @@ import logging from typing import Any, Dict, List from flask_appbuilder.models.sqla import Model -from flask_appbuilder.security.sqla.models import User from marshmallow import ValidationError from superset.annotation_layers.commands.exceptions import ( @@ -34,8 +33,7 @@ logger = logging.getLogger(__name__) class CreateAnnotationLayerCommand(BaseCommand): - def __init__(self, user: User, data: Dict[str, Any]): - self._actor = user + def __init__(self, data: Dict[str, Any]): self._properties = data.copy() def run(self) -> Model: diff --git a/superset/annotation_layers/commands/delete.py b/superset/annotation_layers/commands/delete.py index c439542b2..3dbd7a574 100644 --- a/superset/annotation_layers/commands/delete.py +++ b/superset/annotation_layers/commands/delete.py @@ -18,7 +18,6 @@ import logging from typing import Optional from flask_appbuilder.models.sqla import Model -from flask_appbuilder.security.sqla.models import User from superset.annotation_layers.commands.exceptions import ( AnnotationLayerDeleteFailedError, @@ -34,8 +33,7 @@ logger = logging.getLogger(__name__) class DeleteAnnotationLayerCommand(BaseCommand): - def __init__(self, user: User, model_id: int): - self._actor = user + def __init__(self, model_id: int): self._model_id = model_id self._model: Optional[AnnotationLayer] = None diff --git a/superset/annotation_layers/commands/update.py b/superset/annotation_layers/commands/update.py index d2f48abb2..f4a04cdeb 100644 --- a/superset/annotation_layers/commands/update.py +++ b/superset/annotation_layers/commands/update.py @@ -18,7 +18,6 @@ import logging from typing import Any, Dict, List, Optional from flask_appbuilder.models.sqla import Model -from flask_appbuilder.security.sqla.models import User from marshmallow import ValidationError from superset.annotation_layers.commands.exceptions import ( @@ -36,8 +35,7 @@ logger = logging.getLogger(__name__) class UpdateAnnotationLayerCommand(BaseCommand): - def __init__(self, user: User, model_id: int, data: Dict[str, Any]): - self._actor = user + def __init__(self, model_id: int, data: Dict[str, Any]): self._model_id = model_id self._properties = data.copy() self._model: Optional[AnnotationLayer] = None diff --git a/superset/charts/api.py b/superset/charts/api.py index 153e95bb2..e7e511d4f 100644 --- a/superset/charts/api.py +++ b/superset/charts/api.py @@ -21,7 +21,7 @@ from io import BytesIO from typing import Any, Optional from zipfile import ZipFile -from flask import g, redirect, request, Response, send_file, url_for +from flask import redirect, request, Response, send_file, url_for from flask_appbuilder.api import expose, protect, rison, safe from flask_appbuilder.hooks import before_request from flask_appbuilder.models.sqla.interface import SQLAInterface @@ -285,7 +285,7 @@ class ChartRestApi(BaseSupersetModelRestApi): except ValidationError as error: return self.response_400(message=error.messages) try: - new_model = CreateChartCommand(g.user, item).run() + new_model = CreateChartCommand(item).run() return self.response(201, id=new_model.id, result=item) except ChartInvalidError as ex: return self.response_422(message=ex.normalized_messages()) @@ -356,7 +356,7 @@ class ChartRestApi(BaseSupersetModelRestApi): except ValidationError as error: return self.response_400(message=error.messages) try: - changed_model = UpdateChartCommand(g.user, pk, item).run() + changed_model = UpdateChartCommand(pk, item).run() response = self.response(200, id=changed_model.id, result=item) except ChartNotFoundError: response = self.response_404() @@ -416,7 +416,7 @@ class ChartRestApi(BaseSupersetModelRestApi): $ref: '#/components/responses/500' """ try: - DeleteChartCommand(g.user, pk).run() + DeleteChartCommand(pk).run() return self.response(200, message="OK") except ChartNotFoundError: return self.response_404() @@ -476,7 +476,7 @@ class ChartRestApi(BaseSupersetModelRestApi): """ item_ids = kwargs["rison"] try: - BulkDeleteChartCommand(g.user, item_ids).run() + BulkDeleteChartCommand(item_ids).run() return self.response( 200, message=ngettext( diff --git a/superset/charts/commands/bulk_delete.py b/superset/charts/commands/bulk_delete.py index 26a3fce9e..caf8fe039 100644 --- a/superset/charts/commands/bulk_delete.py +++ b/superset/charts/commands/bulk_delete.py @@ -17,9 +17,9 @@ import logging from typing import List, Optional -from flask_appbuilder.security.sqla.models import User from flask_babel import lazy_gettext as _ +from superset import security_manager from superset.charts.commands.exceptions import ( ChartBulkDeleteFailedError, ChartBulkDeleteFailedReportsExistError, @@ -32,14 +32,12 @@ from superset.commands.exceptions import DeleteFailedError from superset.exceptions import SupersetSecurityException from superset.models.slice import Slice from superset.reports.dao import ReportScheduleDAO -from superset.views.base import check_ownership logger = logging.getLogger(__name__) class BulkDeleteChartCommand(BaseCommand): - def __init__(self, user: User, model_ids: List[int]): - self._actor = user + def __init__(self, model_ids: List[int]): self._model_ids = model_ids self._models: Optional[List[Slice]] = None @@ -66,6 +64,6 @@ class BulkDeleteChartCommand(BaseCommand): # Check ownership for model in self._models: try: - check_ownership(model) + security_manager.raise_for_ownership(model) except SupersetSecurityException as ex: raise ChartForbiddenError() from ex diff --git a/superset/charts/commands/create.py b/superset/charts/commands/create.py index 34a25aea2..823834079 100644 --- a/superset/charts/commands/create.py +++ b/superset/charts/commands/create.py @@ -18,8 +18,8 @@ import logging from datetime import datetime from typing import Any, Dict, List, Optional +from flask import g from flask_appbuilder.models.sqla import Model -from flask_appbuilder.security.sqla.models import User from marshmallow import ValidationError from superset.charts.commands.exceptions import ( @@ -37,15 +37,14 @@ logger = logging.getLogger(__name__) class CreateChartCommand(CreateMixin, BaseCommand): - def __init__(self, user: User, data: Dict[str, Any]): - self._actor = user + def __init__(self, data: Dict[str, Any]): self._properties = data.copy() def run(self) -> Model: self.validate() try: self._properties["last_saved_at"] = datetime.now() - self._properties["last_saved_by"] = self._actor + self._properties["last_saved_by"] = g.user chart = ChartDAO.create(self._properties) except DAOCreateFailedError as ex: logger.exception(ex.exception) @@ -73,7 +72,7 @@ class CreateChartCommand(CreateMixin, BaseCommand): self._properties["dashboards"] = dashboards try: - owners = self.populate_owners(self._actor, owner_ids) + owners = self.populate_owners(owner_ids) self._properties["owners"] = owners except ValidationError as ex: exceptions.append(ex) diff --git a/superset/charts/commands/delete.py b/superset/charts/commands/delete.py index faf72c5ef..cb6644c71 100644 --- a/superset/charts/commands/delete.py +++ b/superset/charts/commands/delete.py @@ -18,9 +18,9 @@ import logging from typing import Optional from flask_appbuilder.models.sqla import Model -from flask_appbuilder.security.sqla.models import User from flask_babel import lazy_gettext as _ +from superset import security_manager from superset.charts.commands.exceptions import ( ChartDeleteFailedError, ChartDeleteFailedReportsExistError, @@ -34,14 +34,12 @@ from superset.exceptions import SupersetSecurityException from superset.models.dashboard import Dashboard from superset.models.slice import Slice from superset.reports.dao import ReportScheduleDAO -from superset.views.base import check_ownership logger = logging.getLogger(__name__) class DeleteChartCommand(BaseCommand): - def __init__(self, user: User, model_id: int): - self._actor = user + def __init__(self, model_id: int): self._model_id = model_id self._model: Optional[Slice] = None @@ -69,6 +67,6 @@ class DeleteChartCommand(BaseCommand): ) # Check ownership try: - check_ownership(self._model) + security_manager.raise_for_ownership(self._model) except SupersetSecurityException as ex: raise ChartForbiddenError() from ex diff --git a/superset/charts/commands/update.py b/superset/charts/commands/update.py index 0355e6e5f..e613222b3 100644 --- a/superset/charts/commands/update.py +++ b/superset/charts/commands/update.py @@ -18,10 +18,11 @@ import logging from datetime import datetime from typing import Any, Dict, List, Optional +from flask import g from flask_appbuilder.models.sqla import Model -from flask_appbuilder.security.sqla.models import User from marshmallow import ValidationError +from superset import security_manager from superset.charts.commands.exceptions import ( ChartForbiddenError, ChartInvalidError, @@ -37,7 +38,6 @@ from superset.dao.exceptions import DAOUpdateFailedError from superset.dashboards.dao import DashboardDAO from superset.exceptions import SupersetSecurityException from superset.models.slice import Slice -from superset.views.base import check_ownership logger = logging.getLogger(__name__) @@ -49,8 +49,7 @@ def is_query_context_update(properties: Dict[str, Any]) -> bool: class UpdateChartCommand(UpdateMixin, BaseCommand): - def __init__(self, user: User, model_id: int, data: Dict[str, Any]): - self._actor = user + def __init__(self, model_id: int, data: Dict[str, Any]): self._model_id = model_id self._properties = data.copy() self._model: Optional[Slice] = None @@ -60,7 +59,7 @@ class UpdateChartCommand(UpdateMixin, BaseCommand): try: if self._properties.get("query_context_generation") is None: self._properties["last_saved_at"] = datetime.now() - self._properties["last_saved_by"] = self._actor + self._properties["last_saved_by"] = g.user chart = ChartDAO.update(self._model, self._properties) except DAOUpdateFailedError as ex: logger.exception(ex.exception) @@ -88,8 +87,8 @@ class UpdateChartCommand(UpdateMixin, BaseCommand): # ownership so the update can be performed by report workers if not is_query_context_update(self._properties): try: - check_ownership(self._model) - owners = self.populate_owners(self._actor, owner_ids) + security_manager.raise_for_ownership(self._model) + owners = self.populate_owners(owner_ids) self._properties["owners"] = owners except SupersetSecurityException as ex: raise ChartForbiddenError() from ex diff --git a/superset/commands/base.py b/superset/commands/base.py index 552b95feb..42d595631 100644 --- a/superset/commands/base.py +++ b/superset/commands/base.py @@ -45,34 +45,28 @@ class BaseCommand(ABC): class CreateMixin: # pylint: disable=too-few-public-methods @staticmethod - def populate_owners( - user: User, owner_ids: Optional[List[int]] = None - ) -> List[User]: + def populate_owners(owner_ids: Optional[List[int]] = None) -> List[User]: """ Populate list of owners, defaulting to the current user if `owner_ids` is undefined or empty. If current user is missing in `owner_ids`, current user is added unless belonging to the Admin role. - :param user: current user :param owner_ids: list of owners by id's :raises OwnersNotFoundValidationError: if at least one owner can't be resolved :returns: Final list of owners """ - return populate_owners(user, owner_ids, default_to_user=True) + return populate_owners(owner_ids, default_to_user=True) class UpdateMixin: # pylint: disable=too-few-public-methods @staticmethod - def populate_owners( - user: User, owner_ids: Optional[List[int]] = None - ) -> List[User]: + def populate_owners(owner_ids: Optional[List[int]] = None) -> List[User]: """ Populate list of owners. If current user is missing in `owner_ids`, current user is added unless belonging to the Admin role. - :param user: current user :param owner_ids: list of owners by id's :raises OwnersNotFoundValidationError: if at least one owner can't be resolved :returns: Final list of owners """ - return populate_owners(user, owner_ids, default_to_user=False) + return populate_owners(owner_ids, default_to_user=False) diff --git a/superset/commands/utils.py b/superset/commands/utils.py index 0be5c52e3..ad58bb405 100644 --- a/superset/commands/utils.py +++ b/superset/commands/utils.py @@ -18,8 +18,10 @@ from __future__ import annotations from typing import List, Optional, TYPE_CHECKING +from flask import g from flask_appbuilder.security.sqla.models import Role, User +from superset import security_manager from superset.commands.exceptions import ( DatasourceNotFoundValidationError, OwnersNotFoundValidationError, @@ -27,21 +29,20 @@ from superset.commands.exceptions import ( ) from superset.dao.exceptions import DatasourceNotFound from superset.datasource.dao import DatasourceDAO -from superset.extensions import db, security_manager -from superset.utils.core import DatasourceType +from superset.extensions import db +from superset.utils.core import DatasourceType, get_user_id if TYPE_CHECKING: from superset.connectors.base.models import BaseDatasource def populate_owners( - user: User, owner_ids: Optional[List[int]], default_to_user: bool, ) -> List[User]: """ Helper function for commands, will fetch all users from owners id's - :param user: current user + :param owner_ids: list of owners by id's :param default_to_user: make user the owner if `owner_ids` is None or empty :raises OwnersNotFoundValidationError: if at least one owner id can't be resolved @@ -50,12 +51,10 @@ def populate_owners( owner_ids = owner_ids or [] owners = [] if not owner_ids and default_to_user: - return [user] - if user.id not in owner_ids and "admin" not in [ - role.name.lower() for role in user.roles - ]: + return [g.user] + if not (security_manager.is_admin() or get_user_id() in owner_ids): # make sure non-admins can't remove themselves as owner by mistake - owners.append(user) + owners.append(g.user) for owner_id in owner_ids: owner = security_manager.get_user_by_id(owner_id) if not owner: diff --git a/superset/common/request_contexed_based.py b/superset/common/request_contexed_based.py deleted file mode 100644 index 5d8405e36..000000000 --- a/superset/common/request_contexed_based.py +++ /dev/null @@ -1,25 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -from superset import conf, security_manager - - -def is_user_admin() -> bool: - user_roles = [role.name.lower() for role in security_manager.get_user_roles()] - admin_role = conf.get("AUTH_ROLE_ADMIN").lower() - return admin_role in user_roles diff --git a/superset/connectors/sqla/views.py b/superset/connectors/sqla/views.py index 2f79f9cf5..18533905e 100644 --- a/superset/connectors/sqla/views.py +++ b/superset/connectors/sqla/views.py @@ -60,9 +60,10 @@ class SelectDataRequired(DataRequired): # pylint: disable=too-few-public-method field_flags = () -class TableColumnInlineView( - CompactCRUDMixin, SupersetModelView -): # pylint: disable=too-many-ancestors +class TableColumnInlineView( # pylint: disable=too-many-ancestors + CompactCRUDMixin, + SupersetModelView, +): datamodel = SQLAInterface(models.TableColumn) # TODO TODO, review need for this on related_views class_permission_name = "Dataset" @@ -194,9 +195,10 @@ class TableColumnInlineView( edit_form_extra_fields = add_form_extra_fields -class SqlMetricInlineView( - CompactCRUDMixin, SupersetModelView -): # pylint: disable=too-many-ancestors +class SqlMetricInlineView( # pylint: disable=too-many-ancestors + CompactCRUDMixin, + SupersetModelView, +): datamodel = SQLAInterface(models.SqlMetric) class_permission_name = "Dataset" method_permission_name = MODEL_VIEW_RW_METHOD_PERMISSION_MAP @@ -278,9 +280,9 @@ class RowLevelSecurityListWidget( super().__init__(**kwargs) -class RowLevelSecurityFiltersModelView( +class RowLevelSecurityFiltersModelView( # pylint: disable=too-many-ancestors SupersetModelView, DeleteMixin -): # pylint: disable=too-many-ancestors +): datamodel = SQLAInterface(models.RowLevelSecurityFilter) list_widget = cast(SupersetListWidget, RowLevelSecurityListWidget) diff --git a/superset/css_templates/api.py b/superset/css_templates/api.py index 5cc36f400..ae367985e 100644 --- a/superset/css_templates/api.py +++ b/superset/css_templates/api.py @@ -17,7 +17,7 @@ import logging from typing import Any -from flask import g, Response +from flask import Response from flask_appbuilder.api import expose, protect, rison, safe from flask_appbuilder.models.sqla.interface import SQLAInterface from flask_babel import ngettext @@ -130,7 +130,7 @@ class CssTemplateRestApi(BaseSupersetModelRestApi): """ item_ids = kwargs["rison"] try: - BulkDeleteCssTemplateCommand(g.user, item_ids).run() + BulkDeleteCssTemplateCommand(item_ids).run() return self.response( 200, message=ngettext( diff --git a/superset/css_templates/commands/bulk_delete.py b/superset/css_templates/commands/bulk_delete.py index 839dbd26a..93564208c 100644 --- a/superset/css_templates/commands/bulk_delete.py +++ b/superset/css_templates/commands/bulk_delete.py @@ -17,8 +17,6 @@ import logging from typing import List, Optional -from flask_appbuilder.security.sqla.models import User - from superset.commands.base import BaseCommand from superset.css_templates.commands.exceptions import ( CssTemplateBulkDeleteFailedError, @@ -32,8 +30,7 @@ logger = logging.getLogger(__name__) class BulkDeleteCssTemplateCommand(BaseCommand): - def __init__(self, user: User, model_ids: List[int]): - self._actor = user + def __init__(self, model_ids: List[int]): self._model_ids = model_ids self._models: Optional[List[CssTemplate]] = None diff --git a/superset/dashboards/api.py b/superset/dashboards/api.py index ce8b75441..9a323923c 100644 --- a/superset/dashboards/api.py +++ b/superset/dashboards/api.py @@ -23,7 +23,7 @@ from io import BytesIO from typing import Any, Callable, Optional from zipfile import is_zipfile, ZipFile -from flask import g, make_response, redirect, request, Response, send_file, url_for +from flask import make_response, redirect, request, Response, send_file, url_for from flask_appbuilder import permission_name from flask_appbuilder.api import expose, protect, rison, safe from flask_appbuilder.hooks import before_request @@ -504,7 +504,7 @@ class DashboardRestApi(BaseSupersetModelRestApi): except ValidationError as error: return self.response_400(message=error.messages) try: - new_model = CreateDashboardCommand(g.user, item).run() + new_model = CreateDashboardCommand(item).run() return self.response(201, id=new_model.id, result=item) except DashboardInvalidError as ex: return self.response_422(message=ex.normalized_messages()) @@ -577,7 +577,7 @@ class DashboardRestApi(BaseSupersetModelRestApi): except ValidationError as error: return self.response_400(message=error.messages) try: - changed_model = UpdateDashboardCommand(g.user, pk, item).run() + changed_model = UpdateDashboardCommand(pk, item).run() last_modified_time = changed_model.changed_on.replace( microsecond=0 ).timestamp() @@ -644,7 +644,7 @@ class DashboardRestApi(BaseSupersetModelRestApi): $ref: '#/components/responses/500' """ try: - DeleteDashboardCommand(g.user, pk).run() + DeleteDashboardCommand(pk).run() return self.response(200, message="OK") except DashboardNotFoundError: return self.response_404() @@ -704,7 +704,7 @@ class DashboardRestApi(BaseSupersetModelRestApi): """ item_ids = kwargs["rison"] try: - BulkDeleteDashboardCommand(g.user, item_ids).run() + BulkDeleteDashboardCommand(item_ids).run() return self.response( 200, message=ngettext( @@ -942,6 +942,7 @@ class DashboardRestApi(BaseSupersetModelRestApi): dashboards = DashboardDAO.find_by_ids(requested_ids) if not dashboards: return self.response_404() + favorited_dashboard_ids = DashboardDAO.favorited_ids(dashboards) res = [ {"id": request_id, "value": request_id in favorited_dashboard_ids} diff --git a/superset/dashboards/commands/bulk_delete.py b/superset/dashboards/commands/bulk_delete.py index 958dea27d..52f599843 100644 --- a/superset/dashboards/commands/bulk_delete.py +++ b/superset/dashboards/commands/bulk_delete.py @@ -17,9 +17,9 @@ import logging from typing import List, Optional -from flask_appbuilder.security.sqla.models import User from flask_babel import lazy_gettext as _ +from superset import security_manager from superset.commands.base import BaseCommand from superset.commands.exceptions import DeleteFailedError from superset.dashboards.commands.exceptions import ( @@ -32,14 +32,12 @@ from superset.dashboards.dao import DashboardDAO from superset.exceptions import SupersetSecurityException from superset.models.dashboard import Dashboard from superset.reports.dao import ReportScheduleDAO -from superset.views.base import check_ownership logger = logging.getLogger(__name__) class BulkDeleteDashboardCommand(BaseCommand): - def __init__(self, user: User, model_ids: List[int]): - self._actor = user + def __init__(self, model_ids: List[int]): self._model_ids = model_ids self._models: Optional[List[Dashboard]] = None @@ -67,6 +65,6 @@ class BulkDeleteDashboardCommand(BaseCommand): # Check ownership for model in self._models: try: - check_ownership(model) + security_manager.raise_for_ownership(model) except SupersetSecurityException as ex: raise DashboardForbiddenError() from ex diff --git a/superset/dashboards/commands/create.py b/superset/dashboards/commands/create.py index 1e796bc31..811508c2e 100644 --- a/superset/dashboards/commands/create.py +++ b/superset/dashboards/commands/create.py @@ -18,7 +18,6 @@ import logging from typing import Any, Dict, List, Optional from flask_appbuilder.models.sqla import Model -from flask_appbuilder.security.sqla.models import User from marshmallow import ValidationError from superset.commands.base import BaseCommand, CreateMixin @@ -35,8 +34,7 @@ logger = logging.getLogger(__name__) class CreateDashboardCommand(CreateMixin, BaseCommand): - def __init__(self, user: User, data: Dict[str, Any]): - self._actor = user + def __init__(self, data: Dict[str, Any]): self._properties = data.copy() def run(self) -> Model: @@ -60,7 +58,7 @@ class CreateDashboardCommand(CreateMixin, BaseCommand): exceptions.append(DashboardSlugExistsValidationError()) try: - owners = self.populate_owners(self._actor, owner_ids) + owners = self.populate_owners(owner_ids) self._properties["owners"] = owners except ValidationError as ex: exceptions.append(ex) diff --git a/superset/dashboards/commands/delete.py b/superset/dashboards/commands/delete.py index 67f683a1c..7af2fdf4c 100644 --- a/superset/dashboards/commands/delete.py +++ b/superset/dashboards/commands/delete.py @@ -18,9 +18,9 @@ import logging from typing import Optional from flask_appbuilder.models.sqla import Model -from flask_appbuilder.security.sqla.models import User from flask_babel import lazy_gettext as _ +from superset import security_manager from superset.commands.base import BaseCommand from superset.dao.exceptions import DAODeleteFailedError from superset.dashboards.commands.exceptions import ( @@ -33,14 +33,12 @@ from superset.dashboards.dao import DashboardDAO from superset.exceptions import SupersetSecurityException from superset.models.dashboard import Dashboard from superset.reports.dao import ReportScheduleDAO -from superset.views.base import check_ownership logger = logging.getLogger(__name__) class DeleteDashboardCommand(BaseCommand): - def __init__(self, user: User, model_id: int): - self._actor = user + def __init__(self, model_id: int): self._model_id = model_id self._model: Optional[Dashboard] = None @@ -67,6 +65,6 @@ class DeleteDashboardCommand(BaseCommand): ) # Check ownership try: - check_ownership(self._model) + security_manager.raise_for_ownership(self._model) except SupersetSecurityException as ex: raise DashboardForbiddenError() from ex diff --git a/superset/dashboards/commands/update.py b/superset/dashboards/commands/update.py index 2065437cc..8b6704b04 100644 --- a/superset/dashboards/commands/update.py +++ b/superset/dashboards/commands/update.py @@ -19,9 +19,9 @@ import logging from typing import Any, Dict, List, Optional from flask_appbuilder.models.sqla import Model -from flask_appbuilder.security.sqla.models import User from marshmallow import ValidationError +from superset import security_manager from superset.commands.base import BaseCommand, UpdateMixin from superset.commands.utils import populate_roles from superset.dao.exceptions import DAOUpdateFailedError @@ -36,14 +36,12 @@ from superset.dashboards.dao import DashboardDAO from superset.exceptions import SupersetSecurityException from superset.extensions import db from superset.models.dashboard import Dashboard -from superset.views.base import check_ownership logger = logging.getLogger(__name__) class UpdateDashboardCommand(UpdateMixin, BaseCommand): - def __init__(self, user: User, model_id: int, data: Dict[str, Any]): - self._actor = user + def __init__(self, model_id: int, data: Dict[str, Any]): self._model_id = model_id self._properties = data.copy() self._model: Optional[Dashboard] = None @@ -77,7 +75,7 @@ class UpdateDashboardCommand(UpdateMixin, BaseCommand): raise DashboardNotFoundError() # Check ownership try: - check_ownership(self._model) + security_manager.raise_for_ownership(self._model) except SupersetSecurityException as ex: raise DashboardForbiddenError() from ex @@ -89,7 +87,7 @@ class UpdateDashboardCommand(UpdateMixin, BaseCommand): if owners_ids is None: owners_ids = [owner.id for owner in self._model.owners] try: - owners = self.populate_owners(self._actor, owners_ids) + owners = self.populate_owners(owners_ids) self._properties["owners"] = owners except ValidationError as ex: exceptions.append(ex) diff --git a/superset/dashboards/filter_sets/api.py b/superset/dashboards/filter_sets/api.py index 3dc2a28de..109ae73f0 100644 --- a/superset/dashboards/filter_sets/api.py +++ b/superset/dashboards/filter_sets/api.py @@ -17,7 +17,7 @@ import logging from typing import Any, cast -from flask import g, request, Response +from flask import request, Response from flask_appbuilder.api import ( expose, get_list_schema, @@ -243,7 +243,7 @@ class FilterSetRestApi(BaseSupersetModelRestApi): """ try: item = self.add_model_schema.load(request.json) - new_model = CreateFilterSetCommand(g.user, dashboard_id, item).run() + new_model = CreateFilterSetCommand(dashboard_id, item).run() return self.response( 201, **self.show_model_schema.dump(new_model, many=False) ) @@ -314,7 +314,7 @@ class FilterSetRestApi(BaseSupersetModelRestApi): """ try: item = self.edit_model_schema.load(request.json) - changed_model = UpdateFilterSetCommand(g.user, dashboard_id, pk, item).run() + changed_model = UpdateFilterSetCommand(dashboard_id, pk, item).run() return self.response( 200, **self.show_model_schema.dump(changed_model, many=False) ) @@ -374,7 +374,7 @@ class FilterSetRestApi(BaseSupersetModelRestApi): $ref: '#/components/responses/500' """ try: - changed_model = DeleteFilterSetCommand(g.user, dashboard_id, pk).run() + changed_model = DeleteFilterSetCommand(dashboard_id, pk).run() return self.response(200, id=changed_model.id) except ValidationError as error: return self.response_400(message=error.messages) diff --git a/superset/dashboards/filter_sets/commands/base.py b/superset/dashboards/filter_sets/commands/base.py index 0e902e5e6..e6a4b03e3 100644 --- a/superset/dashboards/filter_sets/commands/base.py +++ b/superset/dashboards/filter_sets/commands/base.py @@ -18,10 +18,9 @@ import logging from typing import cast, Optional from flask_appbuilder.models.sqla import Model -from flask_appbuilder.security.sqla.models import User +from superset import security_manager from superset.common.not_authrized_object import NotAuthorizedException -from superset.common.request_contexed_based import is_user_admin from superset.dashboards.commands.exceptions import DashboardNotFoundError from superset.dashboards.dao import DashboardDAO from superset.dashboards.filter_sets.commands.exceptions import ( @@ -31,6 +30,7 @@ from superset.dashboards.filter_sets.commands.exceptions import ( from superset.dashboards.filter_sets.consts import USER_OWNER_TYPE from superset.models.dashboard import Dashboard from superset.models.filter_set import FilterSet +from superset.utils.core import get_user_id logger = logging.getLogger(__name__) @@ -41,9 +41,7 @@ class BaseFilterSetCommand: _filter_set_id: Optional[int] _filter_set: Optional[FilterSet] - def __init__(self, user: User, dashboard_id: int): - self._actor = user - self._is_actor_admin = is_user_admin() + def __init__(self, dashboard_id: int): self._dashboard_id = dashboard_id def run(self) -> Model: @@ -54,9 +52,6 @@ class BaseFilterSetCommand: if not self._dashboard: raise DashboardNotFoundError() - def is_user_dashboard_owner(self) -> bool: - return self._is_actor_admin or self._dashboard.is_actor_owner() - def validate_exist_filter_use_cases_set(self) -> None: # pylint: disable=C0103 self._validate_filter_set_exists_and_set_when_exists() self.check_ownership() @@ -70,15 +65,15 @@ class BaseFilterSetCommand: def check_ownership(self) -> None: try: - if not self._is_actor_admin: + if not security_manager.is_admin(): filter_set: FilterSet = cast(FilterSet, self._filter_set) if filter_set.owner_type == USER_OWNER_TYPE: - if self._actor.id != filter_set.owner_id: + if get_user_id() != filter_set.owner_id: raise FilterSetForbiddenError( str(self._filter_set_id), "The user is not the owner of the filter_set", ) - elif not self.is_user_dashboard_owner(): + elif not security_manager.is_owner(self._dashboard): raise FilterSetForbiddenError( str(self._filter_set_id), "The user is not an owner of the filter_set's dashboard", diff --git a/superset/dashboards/filter_sets/commands/create.py b/superset/dashboards/filter_sets/commands/create.py index b74e6d304..de1d70daf 100644 --- a/superset/dashboards/filter_sets/commands/create.py +++ b/superset/dashboards/filter_sets/commands/create.py @@ -17,9 +17,7 @@ import logging from typing import Any, Dict -from flask import g from flask_appbuilder.models.sqla import Model -from flask_appbuilder.security.sqla.models import User from superset import security_manager from superset.dashboards.filter_sets.commands.base import BaseFilterSetCommand @@ -35,14 +33,15 @@ from superset.dashboards.filter_sets.consts import ( OWNER_TYPE_FIELD, ) from superset.dashboards.filter_sets.dao import FilterSetDAO +from superset.utils.core import get_user_id logger = logging.getLogger(__name__) class CreateFilterSetCommand(BaseFilterSetCommand): # pylint: disable=C0103 - def __init__(self, user: User, dashboard_id: int, data: Dict[str, Any]): - super().__init__(user, dashboard_id) + def __init__(self, dashboard_id: int, data: Dict[str, Any]): + super().__init__(dashboard_id) self._properties = data.copy() def run(self) -> Model: @@ -61,13 +60,13 @@ class CreateFilterSetCommand(BaseFilterSetCommand): def _validate_owner_id_exists(self) -> None: owner_id = self._properties[OWNER_ID_FIELD] - if not (g.user.id == owner_id or security_manager.get_user_by_id(owner_id)): + if not (get_user_id() == owner_id or security_manager.get_user_by_id(owner_id)): raise FilterSetCreateFailedError( str(self._dashboard_id), "owner_id does not exists" ) def _validate_user_is_the_dashboard_owner(self) -> None: - if not self.is_user_dashboard_owner(): + if not security_manager.is_owner(self._dashboard): raise UserIsNotDashboardOwnerError(str(self._dashboard_id)) def _validate_owner_id_is_dashboard_id(self) -> None: diff --git a/superset/dashboards/filter_sets/commands/delete.py b/superset/dashboards/filter_sets/commands/delete.py index 18d7fed8f..c41625279 100644 --- a/superset/dashboards/filter_sets/commands/delete.py +++ b/superset/dashboards/filter_sets/commands/delete.py @@ -17,7 +17,6 @@ import logging from flask_appbuilder.models.sqla import Model -from flask_appbuilder.security.sqla.models import User from superset.dao.exceptions import DAODeleteFailedError from superset.dashboards.filter_sets.commands.base import BaseFilterSetCommand @@ -32,8 +31,8 @@ logger = logging.getLogger(__name__) class DeleteFilterSetCommand(BaseFilterSetCommand): - def __init__(self, user: User, dashboard_id: int, filter_set_id: int): - super().__init__(user, dashboard_id) + def __init__(self, dashboard_id: int, filter_set_id: int): + super().__init__(dashboard_id) self._filter_set_id = filter_set_id def run(self) -> Model: diff --git a/superset/dashboards/filter_sets/commands/update.py b/superset/dashboards/filter_sets/commands/update.py index d2c43f085..07d59f93a 100644 --- a/superset/dashboards/filter_sets/commands/update.py +++ b/superset/dashboards/filter_sets/commands/update.py @@ -18,7 +18,6 @@ import logging from typing import Any, Dict from flask_appbuilder.models.sqla import Model -from flask_appbuilder.security.sqla.models import User from superset.dao.exceptions import DAOUpdateFailedError from superset.dashboards.filter_sets.commands.base import BaseFilterSetCommand @@ -32,10 +31,8 @@ logger = logging.getLogger(__name__) class UpdateFilterSetCommand(BaseFilterSetCommand): - def __init__( - self, user: User, dashboard_id: int, filter_set_id: int, data: Dict[str, Any] - ): - super().__init__(user, dashboard_id) + def __init__(self, dashboard_id: int, filter_set_id: int, data: Dict[str, Any]): + super().__init__(dashboard_id) self._filter_set_id = filter_set_id self._properties = data.copy() diff --git a/superset/dashboards/filter_sets/filters.py b/superset/dashboards/filter_sets/filters.py index 0083f40d1..3578e8b0b 100644 --- a/superset/dashboards/filter_sets/filters.py +++ b/superset/dashboards/filter_sets/filters.py @@ -18,13 +18,14 @@ from __future__ import annotations from typing import Any, TYPE_CHECKING -from flask import g from sqlalchemy import and_, or_ +from superset import security_manager from superset.dashboards.filter_sets.consts import DASHBOARD_OWNER_TYPE, USER_OWNER_TYPE from superset.models.dashboard import dashboard_user from superset.models.filter_set import FilterSet -from superset.views.base import BaseFilter, is_user_admin +from superset.utils.core import get_user_id +from superset.views.base import BaseFilter if TYPE_CHECKING: from sqlalchemy.orm.query import Query @@ -32,9 +33,8 @@ if TYPE_CHECKING: class FilterSetFilter(BaseFilter): # pylint: disable=too-few-public-methods) def apply(self, query: Query, value: Any) -> Query: - if is_user_admin(): + if security_manager.is_admin(): return query - current_user_id = g.user.id filter_set_ids_by_dashboard_owners = ( # pylint: disable=C0103 query.from_self(FilterSet.id) @@ -42,7 +42,7 @@ class FilterSetFilter(BaseFilter): # pylint: disable=too-few-public-methods) .filter( and_( FilterSet.owner_type == DASHBOARD_OWNER_TYPE, - dashboard_user.c.user_id == current_user_id, + dashboard_user.c.user_id == get_user_id(), ) ) ) @@ -51,7 +51,7 @@ class FilterSetFilter(BaseFilter): # pylint: disable=too-few-public-methods) or_( and_( FilterSet.owner_type == USER_OWNER_TYPE, - FilterSet.owner_id == current_user_id, + FilterSet.owner_id == get_user_id(), ), FilterSet.id.in_(filter_set_ids_by_dashboard_owners), ) diff --git a/superset/dashboards/filter_state/commands/create.py b/superset/dashboards/filter_state/commands/create.py index 18dff8928..48b5e4f5c 100644 --- a/superset/dashboards/filter_state/commands/create.py +++ b/superset/dashboards/filter_state/commands/create.py @@ -20,17 +20,17 @@ from flask import session from superset.dashboards.filter_state.commands.utils import check_access from superset.extensions import cache_manager -from superset.key_value.utils import get_owner, random_key +from superset.key_value.utils import random_key from superset.temporary_cache.commands.create import CreateTemporaryCacheCommand from superset.temporary_cache.commands.entry import Entry from superset.temporary_cache.commands.parameters import CommandParameters from superset.temporary_cache.utils import cache_key +from superset.utils.core import get_user_id class CreateFilterStateCommand(CreateTemporaryCacheCommand): def create(self, cmd_params: CommandParameters) -> str: resource_id = cmd_params.resource_id - actor = cmd_params.actor tab_id = cmd_params.tab_id contextual_key = cache_key(session.get("_id"), tab_id, resource_id) key = cache_manager.filter_state_cache.get(contextual_key) @@ -38,7 +38,7 @@ class CreateFilterStateCommand(CreateTemporaryCacheCommand): key = random_key() value = cast(str, cmd_params.value) # schema ensures that value is not optional check_access(resource_id) - entry: Entry = {"owner": get_owner(actor), "value": value} + entry: Entry = {"owner": get_user_id(), "value": value} cache_manager.filter_state_cache.set(cache_key(resource_id, key), entry) cache_manager.filter_state_cache.set(contextual_key, key) return key diff --git a/superset/dashboards/filter_state/commands/delete.py b/superset/dashboards/filter_state/commands/delete.py index 3ddc08fc5..6086388a8 100644 --- a/superset/dashboards/filter_state/commands/delete.py +++ b/superset/dashboards/filter_state/commands/delete.py @@ -18,23 +18,22 @@ from flask import session from superset.dashboards.filter_state.commands.utils import check_access from superset.extensions import cache_manager -from superset.key_value.utils import get_owner from superset.temporary_cache.commands.delete import DeleteTemporaryCacheCommand from superset.temporary_cache.commands.entry import Entry from superset.temporary_cache.commands.exceptions import TemporaryCacheAccessDeniedError from superset.temporary_cache.commands.parameters import CommandParameters from superset.temporary_cache.utils import cache_key +from superset.utils.core import get_user_id class DeleteFilterStateCommand(DeleteTemporaryCacheCommand): def delete(self, cmd_params: CommandParameters) -> bool: resource_id = cmd_params.resource_id - actor = cmd_params.actor key = cache_key(resource_id, cmd_params.key) check_access(resource_id) entry: Entry = cache_manager.filter_state_cache.get(key) if entry: - if entry["owner"] != get_owner(actor): + if entry["owner"] != get_user_id(): raise TemporaryCacheAccessDeniedError() tab_id = cmd_params.tab_id contextual_key = cache_key(session.get("_id"), tab_id, resource_id) diff --git a/superset/dashboards/filter_state/commands/update.py b/superset/dashboards/filter_state/commands/update.py index 7f150aae6..c1dc529cc 100644 --- a/superset/dashboards/filter_state/commands/update.py +++ b/superset/dashboards/filter_state/commands/update.py @@ -20,23 +20,23 @@ from flask import session from superset.dashboards.filter_state.commands.utils import check_access from superset.extensions import cache_manager -from superset.key_value.utils import get_owner, random_key +from superset.key_value.utils import random_key from superset.temporary_cache.commands.entry import Entry from superset.temporary_cache.commands.exceptions import TemporaryCacheAccessDeniedError from superset.temporary_cache.commands.parameters import CommandParameters from superset.temporary_cache.commands.update import UpdateTemporaryCacheCommand from superset.temporary_cache.utils import cache_key +from superset.utils.core import get_user_id class UpdateFilterStateCommand(UpdateTemporaryCacheCommand): def update(self, cmd_params: CommandParameters) -> Optional[str]: resource_id = cmd_params.resource_id - actor = cmd_params.actor key = cmd_params.key value = cast(str, cmd_params.value) # schema ensures that value is not optional check_access(resource_id) entry: Entry = cache_manager.filter_state_cache.get(cache_key(resource_id, key)) - owner = get_owner(actor) + owner = get_user_id() if entry: if entry["owner"] != owner: raise TemporaryCacheAccessDeniedError() diff --git a/superset/dashboards/filters.py b/superset/dashboards/filters.py index 2cc9d3b9a..f765bc8ff 100644 --- a/superset/dashboards/filters.py +++ b/superset/dashboards/filters.py @@ -30,7 +30,7 @@ from superset.models.embedded_dashboard import EmbeddedDashboard from superset.models.slice import Slice from superset.security.guest_token import GuestTokenResourceType, GuestUser from superset.utils.core import get_user_id -from superset.views.base import BaseFilter, is_user_admin +from superset.views.base import BaseFilter from superset.views.base_api import BaseFavoriteFilter @@ -98,7 +98,7 @@ class DashboardAccessFilter(BaseFilter): # pylint: disable=too-few-public-metho """ def apply(self, query: Query, value: Any) -> Query: - if is_user_admin(): + if security_manager.is_admin(): return query datasource_perms = security_manager.user_view_menu_names("datasource_access") diff --git a/superset/dashboards/permalink/api.py b/superset/dashboards/permalink/api.py index ca536af8f..56b6ca311 100644 --- a/superset/dashboards/permalink/api.py +++ b/superset/dashboards/permalink/api.py @@ -16,7 +16,7 @@ # under the License. import logging -from flask import g, request, Response +from flask import request, Response from flask_appbuilder.api import BaseApi, expose, protect, safe from marshmallow import ValidationError @@ -104,7 +104,6 @@ class DashboardPermalinkRestApi(BaseApi): try: state = self.add_model_schema.load(request.json) key = CreateDashboardPermalinkCommand( - actor=g.user, dashboard_id=pk, state=state, ).run() @@ -162,7 +161,7 @@ class DashboardPermalinkRestApi(BaseApi): $ref: '#/components/responses/500' """ try: - value = GetDashboardPermalinkCommand(actor=g.user, key=key).run() + value = GetDashboardPermalinkCommand(key=key).run() if not value: return self.response_404() return self.response(200, **value) diff --git a/superset/dashboards/permalink/commands/create.py b/superset/dashboards/permalink/commands/create.py index 4ffd41104..b8cbdbd3a 100644 --- a/superset/dashboards/permalink/commands/create.py +++ b/superset/dashboards/permalink/commands/create.py @@ -16,7 +16,6 @@ # under the License. import logging -from flask_appbuilder.security.sqla.models import User from sqlalchemy.exc import SQLAlchemyError from superset.dashboards.dao import DashboardDAO @@ -32,11 +31,9 @@ logger = logging.getLogger(__name__) class CreateDashboardPermalinkCommand(BaseDashboardPermalinkCommand): def __init__( self, - actor: User, dashboard_id: str, state: DashboardPermalinkState, ): - self.actor = actor self.dashboard_id = dashboard_id self.state = state @@ -49,7 +46,6 @@ class CreateDashboardPermalinkCommand(BaseDashboardPermalinkCommand): "state": self.state, } key = CreateKeyValueCommand( - actor=self.actor, resource=self.resource, value=value, ).run() diff --git a/superset/dashboards/permalink/commands/get.py b/superset/dashboards/permalink/commands/get.py index 24bf77834..f89f9444e 100644 --- a/superset/dashboards/permalink/commands/get.py +++ b/superset/dashboards/permalink/commands/get.py @@ -17,7 +17,6 @@ import logging from typing import Optional -from flask_appbuilder.security.sqla.models import User from sqlalchemy.exc import SQLAlchemyError from superset.dashboards.commands.exceptions import DashboardNotFoundError @@ -33,8 +32,7 @@ logger = logging.getLogger(__name__) class GetDashboardPermalinkCommand(BaseDashboardPermalinkCommand): - def __init__(self, actor: User, key: str): - self.actor = actor + def __init__(self, key: str): self.key = key def run(self) -> Optional[DashboardPermalinkValue]: diff --git a/superset/databases/api.py b/superset/databases/api.py index 1afa71c6f..ca8e38aaf 100644 --- a/superset/databases/api.py +++ b/superset/databases/api.py @@ -22,7 +22,7 @@ from io import BytesIO from typing import Any, Dict, List, Optional from zipfile import ZipFile -from flask import g, request, Response, send_file +from flask import request, Response, send_file from flask_appbuilder.api import expose, protect, rison, safe from flask_appbuilder.models.sqla.interface import SQLAInterface from marshmallow import ValidationError @@ -261,7 +261,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi): except ValidationError as error: return self.response_400(message=error.messages) try: - new_model = CreateDatabaseCommand(g.user, item).run() + new_model = CreateDatabaseCommand(item).run() # Return censored version for sqlalchemy URI item["sqlalchemy_uri"] = new_model.sqlalchemy_uri item["expose_in_sqllab"] = new_model.expose_in_sqllab @@ -342,7 +342,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi): except ValidationError as error: return self.response_400(message=error.messages) try: - changed_model = UpdateDatabaseCommand(g.user, pk, item).run() + changed_model = UpdateDatabaseCommand(pk, item).run() # Return censored version for sqlalchemy URI item["sqlalchemy_uri"] = changed_model.sqlalchemy_uri if changed_model.parameters: @@ -404,7 +404,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi): $ref: '#/components/responses/500' """ try: - DeleteDatabaseCommand(g.user, pk).run() + DeleteDatabaseCommand(pk).run() return self.response(200, message="OK") except DatabaseNotFoundError: return self.response_404() @@ -706,7 +706,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi): # This validates custom Schema with custom validations except ValidationError as error: return self.response_400(message=error.messages) - TestConnectionDatabaseCommand(g.user, item).run() + TestConnectionDatabaseCommand(item).run() return self.response(200, message="OK") @expose("//related_objects/", methods=["GET"]) @@ -1174,6 +1174,6 @@ class DatabaseRestApi(BaseSupersetModelRestApi): ] raise InvalidParametersError(errors) from ex - command = ValidateDatabaseParametersCommand(g.user, payload) + command = ValidateDatabaseParametersCommand(payload) command.run() return self.response(200, message="OK") diff --git a/superset/databases/commands/create.py b/superset/databases/commands/create.py index e91ccec45..658dbb535 100644 --- a/superset/databases/commands/create.py +++ b/superset/databases/commands/create.py @@ -18,7 +18,6 @@ import logging from typing import Any, Dict, List, Optional from flask_appbuilder.models.sqla import Model -from flask_appbuilder.security.sqla.models import User from marshmallow import ValidationError from superset.commands.base import BaseCommand @@ -38,8 +37,7 @@ logger = logging.getLogger(__name__) class CreateDatabaseCommand(BaseCommand): - def __init__(self, user: User, data: Dict[str, Any]): - self._actor = user + def __init__(self, data: Dict[str, Any]): self._properties = data.copy() def run(self) -> Model: @@ -47,7 +45,7 @@ class CreateDatabaseCommand(BaseCommand): try: # Test connection before starting create transaction - TestConnectionDatabaseCommand(self._actor, self._properties).run() + TestConnectionDatabaseCommand(self._properties).run() except Exception as ex: event_logger.log_with_context( action=f"db_creation_failed.{ex.__class__.__name__}", diff --git a/superset/databases/commands/delete.py b/superset/databases/commands/delete.py index 61bd7ad0a..ebdd54357 100644 --- a/superset/databases/commands/delete.py +++ b/superset/databases/commands/delete.py @@ -18,7 +18,6 @@ import logging from typing import Optional from flask_appbuilder.models.sqla import Model -from flask_appbuilder.security.sqla.models import User from flask_babel import lazy_gettext as _ from superset.commands.base import BaseCommand @@ -37,8 +36,7 @@ logger = logging.getLogger(__name__) class DeleteDatabaseCommand(BaseCommand): - def __init__(self, user: User, model_id: int): - self._actor = user + def __init__(self, model_id: int): self._model_id = model_id self._model: Optional[Database] = None diff --git a/superset/databases/commands/test_connection.py b/superset/databases/commands/test_connection.py index 2e217ab01..9f46a71e3 100644 --- a/superset/databases/commands/test_connection.py +++ b/superset/databases/commands/test_connection.py @@ -20,7 +20,6 @@ from contextlib import closing from typing import Any, Dict, Optional from flask import current_app as app -from flask_appbuilder.security.sqla.models import User from flask_babel import gettext as _ from func_timeout import func_timeout, FunctionTimedOut from sqlalchemy.engine import Engine @@ -39,14 +38,12 @@ from superset.errors import ErrorLevel, SupersetErrorType from superset.exceptions import SupersetSecurityException, SupersetTimeoutException from superset.extensions import event_logger from superset.models.core import Database -from superset.utils.core import override_user logger = logging.getLogger(__name__) class TestConnectionDatabaseCommand(BaseCommand): - def __init__(self, user: User, data: Dict[str, Any]): - self._actor = user + def __init__(self, data: Dict[str, Any]): self._properties = data.copy() self._model: Optional[Database] = None @@ -77,47 +74,41 @@ class TestConnectionDatabaseCommand(BaseCommand): database.set_sqlalchemy_uri(uri) database.db_engine_spec.mutate_db_for_connection_test(database) - with override_user(self._actor): - engine = database.get_sqla_engine() - event_logger.log_with_context( - action="test_connection_attempt", - engine=database.db_engine_spec.__name__, + engine = database.get_sqla_engine() + event_logger.log_with_context( + action="test_connection_attempt", + engine=database.db_engine_spec.__name__, + ) + + def ping(engine: Engine) -> bool: + with closing(engine.raw_connection()) as conn: + return engine.dialect.do_ping(conn) + + try: + alive = func_timeout( + int(app.config["TEST_DATABASE_CONNECTION_TIMEOUT"].total_seconds()), + ping, + args=(engine,), ) - - def ping(engine: Engine) -> bool: - with closing(engine.raw_connection()) as conn: - return engine.dialect.do_ping(conn) - - try: - alive = func_timeout( - int( - app.config[ - "TEST_DATABASE_CONNECTION_TIMEOUT" - ].total_seconds() - ), - ping, - args=(engine,), - ) - - except (sqlite3.ProgrammingError, RuntimeError): - # SQLite can't run on a separate thread, so ``func_timeout`` fails - # RuntimeError catches the equivalent error from duckdb. - alive = engine.dialect.do_ping(engine) - except FunctionTimedOut as ex: - raise SupersetTimeoutException( - error_type=SupersetErrorType.CONNECTION_DATABASE_TIMEOUT, - message=( - "Please check your connection details and database settings, " - "and ensure that your database is accepting connections, " - "then try connecting again." - ), - level=ErrorLevel.ERROR, - extra={"sqlalchemy_uri": database.sqlalchemy_uri}, - ) from ex - except Exception: # pylint: disable=broad-except - alive = False - if not alive: - raise DBAPIError(None, None, None) + except (sqlite3.ProgrammingError, RuntimeError): + # SQLite can't run on a separate thread, so ``func_timeout`` fails + # RuntimeError catches the equivalent error from duckdb. + alive = engine.dialect.do_ping(engine) + except FunctionTimedOut as ex: + raise SupersetTimeoutException( + error_type=SupersetErrorType.CONNECTION_DATABASE_TIMEOUT, + message=( + "Please check your connection details and database settings, " + "and ensure that your database is accepting connections, " + "then try connecting again." + ), + level=ErrorLevel.ERROR, + extra={"sqlalchemy_uri": database.sqlalchemy_uri}, + ) from ex + except Exception: # pylint: disable=broad-except + alive = False + if not alive: + raise DBAPIError(None, None, None) # Log succesful connection test with engine event_logger.log_with_context( diff --git a/superset/databases/commands/update.py b/superset/databases/commands/update.py index 69b6c30e7..fadc8ba25 100644 --- a/superset/databases/commands/update.py +++ b/superset/databases/commands/update.py @@ -18,7 +18,6 @@ import logging from typing import Any, Dict, List, Optional from flask_appbuilder.models.sqla import Model -from flask_appbuilder.security.sqla.models import User from marshmallow import ValidationError from superset.commands.base import BaseCommand @@ -38,8 +37,7 @@ logger = logging.getLogger(__name__) class UpdateDatabaseCommand(BaseCommand): - def __init__(self, user: User, model_id: int, data: Dict[str, Any]): - self._actor = user + def __init__(self, model_id: int, data: Dict[str, Any]): self._properties = data.copy() self._model_id = model_id self._model: Optional[Database] = None diff --git a/superset/databases/commands/validate.py b/superset/databases/commands/validate.py index 145965fc6..a9f1633a1 100644 --- a/superset/databases/commands/validate.py +++ b/superset/databases/commands/validate.py @@ -18,7 +18,6 @@ import json from contextlib import closing from typing import Any, Dict, Optional -from flask_appbuilder.security.sqla.models import User from flask_babel import gettext as __ from superset.commands.base import BaseCommand @@ -35,14 +34,12 @@ from superset.db_engine_specs.base import BasicParametersMixin from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.extensions import event_logger from superset.models.core import Database -from superset.utils.core import override_user BYPASS_VALIDATION_ENGINES = {"bigquery"} class ValidateDatabaseParametersCommand(BaseCommand): - def __init__(self, user: User, parameters: Dict[str, Any]): - self._actor = user + def __init__(self, parameters: Dict[str, Any]): self._properties = parameters.copy() self._model: Optional[Database] = None @@ -117,22 +114,21 @@ class ValidateDatabaseParametersCommand(BaseCommand): database.set_sqlalchemy_uri(sqlalchemy_uri) database.db_engine_spec.mutate_db_for_connection_test(database) - with override_user(self._actor): - engine = database.get_sqla_engine() - try: - with closing(engine.raw_connection()) as conn: - alive = engine.dialect.do_ping(conn) - except Exception as ex: - url = make_url_safe(sqlalchemy_uri) - context = { - "hostname": url.host, - "password": url.password, - "port": url.port, - "username": url.username, - "database": url.database, - } - errors = database.db_engine_spec.extract_errors(ex, context) - raise DatabaseTestConnectionFailedError(errors) from ex + engine = database.get_sqla_engine() + try: + with closing(engine.raw_connection()) as conn: + alive = engine.dialect.do_ping(conn) + except Exception as ex: + url = make_url_safe(sqlalchemy_uri) + context = { + "hostname": url.host, + "password": url.password, + "port": url.port, + "username": url.username, + "database": url.database, + } + errors = database.db_engine_spec.extract_errors(ex, context) + raise DatabaseTestConnectionFailedError(errors) from ex if not alive: raise DatabaseOfflineError( diff --git a/superset/datasets/api.py b/superset/datasets/api.py index db6136865..6c4c896ae 100644 --- a/superset/datasets/api.py +++ b/superset/datasets/api.py @@ -23,7 +23,7 @@ from zipfile import is_zipfile, ZipFile import simplejson import yaml -from flask import g, make_response, request, Response, send_file +from flask import make_response, request, Response, send_file from flask_appbuilder.api import expose, protect, rison, safe from flask_appbuilder.models.sqla.interface import SQLAInterface from flask_babel import ngettext @@ -264,7 +264,7 @@ class DatasetRestApi(BaseSupersetModelRestApi): return self.response_400(message=error.messages) try: - new_model = CreateDatasetCommand(g.user, item).run() + new_model = CreateDatasetCommand(item).run() return self.response(201, id=new_model.id, result=item) except DatasetInvalidError as ex: return self.response_422(message=ex.normalized_messages()) @@ -344,11 +344,9 @@ class DatasetRestApi(BaseSupersetModelRestApi): except ValidationError as error: return self.response_400(message=error.messages) try: - changed_model = UpdateDatasetCommand( - g.user, pk, item, override_columns - ).run() + changed_model = UpdateDatasetCommand(pk, item, override_columns).run() if override_columns: - RefreshDatasetCommand(g.user, pk).run() + RefreshDatasetCommand(pk).run() response = self.response(200, id=changed_model.id, result=item) except DatasetNotFoundError: response = self.response_404() @@ -407,7 +405,7 @@ class DatasetRestApi(BaseSupersetModelRestApi): $ref: '#/components/responses/500' """ try: - DeleteDatasetCommand(g.user, pk).run() + DeleteDatasetCommand(pk).run() return self.response(200, message="OK") except DatasetNotFoundError: return self.response_404() @@ -547,7 +545,7 @@ class DatasetRestApi(BaseSupersetModelRestApi): $ref: '#/components/responses/500' """ try: - RefreshDatasetCommand(g.user, pk).run() + RefreshDatasetCommand(pk).run() return self.response(200, message="OK") except DatasetNotFoundError: return self.response_404() @@ -671,7 +669,7 @@ class DatasetRestApi(BaseSupersetModelRestApi): """ item_ids = kwargs["rison"] try: - BulkDeleteDatasetCommand(g.user, item_ids).run() + BulkDeleteDatasetCommand(item_ids).run() return self.response( 200, message=ngettext( @@ -812,7 +810,7 @@ class DatasetRestApi(BaseSupersetModelRestApi): """ try: force = parse_boolean_string(request.args.get("force")) - rv = SamplesDatasetCommand(g.user, pk, force).run() + rv = SamplesDatasetCommand(pk, force).run() response_data = simplejson.dumps( {"result": rv}, default=json_int_dttm_ser, diff --git a/superset/datasets/columns/api.py b/superset/datasets/columns/api.py index d04827d42..cd260a423 100644 --- a/superset/datasets/columns/api.py +++ b/superset/datasets/columns/api.py @@ -16,7 +16,7 @@ # under the License. import logging -from flask import g, Response +from flask import Response from flask_appbuilder.api import expose, permission_name, protect, safe from flask_appbuilder.models.sqla.interface import SQLAInterface @@ -91,7 +91,7 @@ class DatasetColumnsRestApi(BaseSupersetModelRestApi): $ref: '#/components/responses/500' """ try: - DeleteDatasetColumnCommand(g.user, pk, column_id).run() + DeleteDatasetColumnCommand(pk, column_id).run() return self.response(200, message="OK") except DatasetColumnNotFoundError: return self.response_404() diff --git a/superset/datasets/columns/commands/delete.py b/superset/datasets/columns/commands/delete.py index f5914af52..8fb27f938 100644 --- a/superset/datasets/columns/commands/delete.py +++ b/superset/datasets/columns/commands/delete.py @@ -18,8 +18,8 @@ import logging from typing import Optional from flask_appbuilder.models.sqla import Model -from flask_appbuilder.security.sqla.models import User +from superset import security_manager from superset.commands.base import BaseCommand from superset.connectors.sqla.models import TableColumn from superset.dao.exceptions import DAODeleteFailedError @@ -30,14 +30,12 @@ from superset.datasets.columns.commands.exceptions import ( ) from superset.datasets.dao import DatasetDAO from superset.exceptions import SupersetSecurityException -from superset.views.base import check_ownership logger = logging.getLogger(__name__) class DeleteDatasetColumnCommand(BaseCommand): - def __init__(self, user: User, dataset_id: int, model_id: int): - self._actor = user + def __init__(self, dataset_id: int, model_id: int): self._dataset_id = dataset_id self._model_id = model_id self._model: Optional[TableColumn] = None @@ -60,6 +58,6 @@ class DeleteDatasetColumnCommand(BaseCommand): raise DatasetColumnNotFoundError() # Check ownership try: - check_ownership(self._model) + security_manager.raise_for_ownership(self._model) except SupersetSecurityException as ex: raise DatasetColumnForbiddenError() from ex diff --git a/superset/datasets/commands/bulk_delete.py b/superset/datasets/commands/bulk_delete.py index 13608eda8..643ac784e 100644 --- a/superset/datasets/commands/bulk_delete.py +++ b/superset/datasets/commands/bulk_delete.py @@ -17,8 +17,7 @@ import logging from typing import List, Optional -from flask_appbuilder.security.sqla.models import User - +from superset import security_manager from superset.commands.base import BaseCommand from superset.commands.exceptions import DeleteFailedError from superset.connectors.sqla.models import SqlaTable @@ -29,15 +28,13 @@ from superset.datasets.commands.exceptions import ( ) from superset.datasets.dao import DatasetDAO from superset.exceptions import SupersetSecurityException -from superset.extensions import db, security_manager -from superset.views.base import check_ownership +from superset.extensions import db logger = logging.getLogger(__name__) class BulkDeleteDatasetCommand(BaseCommand): - def __init__(self, user: User, model_ids: List[int]): - self._actor = user + def __init__(self, model_ids: List[int]): self._model_ids = model_ids self._models: Optional[List[SqlaTable]] = None @@ -84,6 +81,6 @@ class BulkDeleteDatasetCommand(BaseCommand): # Check ownership for model in self._models: try: - check_ownership(model) + security_manager.raise_for_ownership(model) except SupersetSecurityException as ex: raise DatasetForbiddenError() from ex diff --git a/superset/datasets/commands/create.py b/superset/datasets/commands/create.py index 4a89b1a81..b638581ab 100644 --- a/superset/datasets/commands/create.py +++ b/superset/datasets/commands/create.py @@ -18,7 +18,6 @@ import logging from typing import Any, Dict, List, Optional from flask_appbuilder.models.sqla import Model -from flask_appbuilder.security.sqla.models import User from marshmallow import ValidationError from sqlalchemy.exc import SQLAlchemyError @@ -38,8 +37,7 @@ logger = logging.getLogger(__name__) class CreateDatasetCommand(CreateMixin, BaseCommand): - def __init__(self, user: User, data: Dict[str, Any]): - self._actor = user + def __init__(self, data: Dict[str, Any]): self._properties = data.copy() def run(self) -> Model: @@ -89,7 +87,7 @@ class CreateDatasetCommand(CreateMixin, BaseCommand): exceptions.append(TableNotFoundValidationError(table_name)) try: - owners = self.populate_owners(self._actor, owner_ids) + owners = self.populate_owners(owner_ids) self._properties["owners"] = owners except ValidationError as ex: exceptions.append(ex) diff --git a/superset/datasets/commands/delete.py b/superset/datasets/commands/delete.py index a9e5a0ab5..9ab8f41a4 100644 --- a/superset/datasets/commands/delete.py +++ b/superset/datasets/commands/delete.py @@ -18,9 +18,9 @@ import logging from typing import Optional from flask_appbuilder.models.sqla import Model -from flask_appbuilder.security.sqla.models import User from sqlalchemy.exc import SQLAlchemyError +from superset import security_manager from superset.commands.base import BaseCommand from superset.connectors.sqla.models import SqlaTable from superset.dao.exceptions import DAODeleteFailedError @@ -31,15 +31,13 @@ from superset.datasets.commands.exceptions import ( ) from superset.datasets.dao import DatasetDAO from superset.exceptions import SupersetSecurityException -from superset.extensions import db, security_manager -from superset.views.base import check_ownership +from superset.extensions import db logger = logging.getLogger(__name__) class DeleteDatasetCommand(BaseCommand): - def __init__(self, user: User, model_id: int): - self._actor = user + def __init__(self, model_id: int): self._model_id = model_id self._model: Optional[SqlaTable] = None @@ -85,6 +83,6 @@ class DeleteDatasetCommand(BaseCommand): raise DatasetNotFoundError() # Check ownership try: - check_ownership(self._model) + security_manager.raise_for_ownership(self._model) except SupersetSecurityException as ex: raise DatasetForbiddenError() from ex diff --git a/superset/datasets/commands/refresh.py b/superset/datasets/commands/refresh.py index 962ffb410..5277c2777 100644 --- a/superset/datasets/commands/refresh.py +++ b/superset/datasets/commands/refresh.py @@ -18,8 +18,8 @@ import logging from typing import Optional from flask_appbuilder.models.sqla import Model -from flask_appbuilder.security.sqla.models import User +from superset import security_manager from superset.commands.base import BaseCommand from superset.connectors.sqla.models import SqlaTable from superset.datasets.commands.exceptions import ( @@ -29,14 +29,12 @@ from superset.datasets.commands.exceptions import ( ) from superset.datasets.dao import DatasetDAO from superset.exceptions import SupersetSecurityException -from superset.views.base import check_ownership logger = logging.getLogger(__name__) class RefreshDatasetCommand(BaseCommand): - def __init__(self, user: User, model_id: int): - self._actor = user + def __init__(self, model_id: int): self._model_id = model_id self._model: Optional[SqlaTable] = None @@ -58,6 +56,6 @@ class RefreshDatasetCommand(BaseCommand): raise DatasetNotFoundError() # Check ownership try: - check_ownership(self._model) + security_manager.raise_for_ownership(self._model) except SupersetSecurityException as ex: raise DatasetForbiddenError() from ex diff --git a/superset/datasets/commands/samples.py b/superset/datasets/commands/samples.py index 79ac729be..e252cfb62 100644 --- a/superset/datasets/commands/samples.py +++ b/superset/datasets/commands/samples.py @@ -17,8 +17,7 @@ import logging from typing import Any, Dict, Optional -from flask_appbuilder.security.sqla.models import User - +from superset import security_manager from superset.commands.base import BaseCommand from superset.common.chart_data import ChartDataResultType from superset.common.query_context_factory import QueryContextFactory @@ -33,14 +32,12 @@ from superset.datasets.commands.exceptions import ( from superset.datasets.dao import DatasetDAO from superset.exceptions import SupersetSecurityException from superset.utils.core import QueryStatus -from superset.views.base import check_ownership logger = logging.getLogger(__name__) class SamplesDatasetCommand(BaseCommand): - def __init__(self, user: User, model_id: int, force: bool): - self._actor = user + def __init__(self, model_id: int, force: bool): self._model_id = model_id self._force = force self._model: Optional[SqlaTable] = None @@ -78,6 +75,6 @@ class SamplesDatasetCommand(BaseCommand): raise DatasetNotFoundError() # Check ownership try: - check_ownership(self._model) + security_manager.raise_for_ownership(self._model) except SupersetSecurityException as ex: raise DatasetForbiddenError() from ex diff --git a/superset/datasets/commands/update.py b/superset/datasets/commands/update.py index 9d448a6c1..e3c908ceb 100644 --- a/superset/datasets/commands/update.py +++ b/superset/datasets/commands/update.py @@ -19,9 +19,9 @@ from collections import Counter from typing import Any, Dict, List, Optional from flask_appbuilder.models.sqla import Model -from flask_appbuilder.security.sqla.models import User from marshmallow import ValidationError +from superset import security_manager from superset.commands.base import BaseCommand, UpdateMixin from superset.connectors.sqla.models import SqlaTable from superset.dao.exceptions import DAOUpdateFailedError @@ -41,7 +41,6 @@ from superset.datasets.commands.exceptions import ( ) from superset.datasets.dao import DatasetDAO from superset.exceptions import SupersetSecurityException -from superset.views.base import check_ownership logger = logging.getLogger(__name__) @@ -49,12 +48,10 @@ logger = logging.getLogger(__name__) class UpdateDatasetCommand(UpdateMixin, BaseCommand): def __init__( self, - user: User, model_id: int, data: Dict[str, Any], override_columns: bool = False, ): - self._actor = user self._model_id = model_id self._properties = data.copy() self._model: Optional[SqlaTable] = None @@ -83,7 +80,7 @@ class UpdateDatasetCommand(UpdateMixin, BaseCommand): raise DatasetNotFoundError() # Check ownership try: - check_ownership(self._model) + security_manager.raise_for_ownership(self._model) except SupersetSecurityException as ex: raise DatasetForbiddenError() from ex @@ -99,7 +96,7 @@ class UpdateDatasetCommand(UpdateMixin, BaseCommand): exceptions.append(DatabaseChangeValidationError()) # Validate/Populate owner try: - owners = self.populate_owners(self._actor, owner_ids) + owners = self.populate_owners(owner_ids) self._properties["owners"] = owners except ValidationError as ex: exceptions.append(ex) diff --git a/superset/datasets/dao.py b/superset/datasets/dao.py index 961bf1fc1..99f5c4d1f 100644 --- a/superset/datasets/dao.py +++ b/superset/datasets/dao.py @@ -17,7 +17,6 @@ import logging from typing import Any, Dict, List, Optional -from flask import current_app from sqlalchemy.exc import SQLAlchemyError from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn @@ -36,14 +35,6 @@ class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods model_cls = SqlaTable base_filter = DatasourceFilter - @staticmethod - def get_owner_by_id(owner_id: int) -> Optional[object]: - return ( - db.session.query(current_app.appbuilder.sm.user_model) - .filter_by(id=owner_id) - .one_or_none() - ) - @staticmethod def get_database_by_id(database_id: int) -> Optional[Database]: try: diff --git a/superset/datasets/metrics/api.py b/superset/datasets/metrics/api.py index b55ab9dbb..c0831670d 100644 --- a/superset/datasets/metrics/api.py +++ b/superset/datasets/metrics/api.py @@ -16,7 +16,7 @@ # under the License. import logging -from flask import g, Response +from flask import Response from flask_appbuilder.api import expose, permission_name, protect, safe from flask_appbuilder.models.sqla.interface import SQLAInterface @@ -91,7 +91,7 @@ class DatasetMetricRestApi(BaseSupersetModelRestApi): $ref: '#/components/responses/500' """ try: - DeleteDatasetMetricCommand(g.user, pk, metric_id).run() + DeleteDatasetMetricCommand(pk, metric_id).run() return self.response(200, message="OK") except DatasetMetricNotFoundError: return self.response_404() diff --git a/superset/datasets/metrics/commands/delete.py b/superset/datasets/metrics/commands/delete.py index cb3f1e0be..d57e7fa35 100644 --- a/superset/datasets/metrics/commands/delete.py +++ b/superset/datasets/metrics/commands/delete.py @@ -18,8 +18,8 @@ import logging from typing import Optional from flask_appbuilder.models.sqla import Model -from flask_appbuilder.security.sqla.models import User +from superset import security_manager from superset.commands.base import BaseCommand from superset.connectors.sqla.models import SqlMetric from superset.dao.exceptions import DAODeleteFailedError @@ -30,14 +30,12 @@ from superset.datasets.metrics.commands.exceptions import ( DatasetMetricNotFoundError, ) from superset.exceptions import SupersetSecurityException -from superset.views.base import check_ownership logger = logging.getLogger(__name__) class DeleteDatasetMetricCommand(BaseCommand): - def __init__(self, user: User, dataset_id: int, model_id: int): - self._actor = user + def __init__(self, dataset_id: int, model_id: int): self._dataset_id = dataset_id self._model_id = model_id self._model: Optional[SqlMetric] = None @@ -60,6 +58,6 @@ class DeleteDatasetMetricCommand(BaseCommand): raise DatasetMetricNotFoundError() # Check ownership try: - check_ownership(self._model) + security_manager.raise_for_ownership(self._model) except SupersetSecurityException as ex: raise DatasetMetricForbiddenError() from ex diff --git a/superset/explore/commands/get.py b/superset/explore/commands/get.py index c05ffdb41..80c8744fd 100644 --- a/superset/explore/commands/get.py +++ b/superset/explore/commands/get.py @@ -54,7 +54,6 @@ class GetExploreCommand(BaseCommand, ABC): self, params: CommandParameters, ) -> None: - self._actor = params.actor self._permalink_key = params.permalink_key self._form_data_key = params.form_data_key self._dataset_id = params.dataset_id @@ -66,7 +65,7 @@ class GetExploreCommand(BaseCommand, ABC): initial_form_data = {} if self._permalink_key is not None: - command = GetExplorePermalinkCommand(self._actor, self._permalink_key) + command = GetExplorePermalinkCommand(self._permalink_key) permalink_value = command.run() if not permalink_value: raise ExplorePermalinkGetFailedError() @@ -76,9 +75,7 @@ class GetExploreCommand(BaseCommand, ABC): if url_params: initial_form_data["url_params"] = dict(url_params) elif self._form_data_key: - parameters = FormDataCommandParameters( - actor=self._actor, key=self._form_data_key - ) + parameters = FormDataCommandParameters(key=self._form_data_key) value = GetFormDataCommand(parameters).run() initial_form_data = json.loads(value) if value else {} diff --git a/superset/explore/form_data/api.py b/superset/explore/form_data/api.py index 42b1d11a0..902156d05 100644 --- a/superset/explore/form_data/api.py +++ b/superset/explore/form_data/api.py @@ -16,7 +16,7 @@ # under the License. import logging -from flask import g, request, Response +from flask import request, Response from flask_appbuilder.api import BaseApi, expose, protect, safe from marshmallow import ValidationError @@ -102,7 +102,6 @@ class ExploreFormDataRestApi(BaseApi): item = self.add_model_schema.load(request.json) tab_id = request.args.get("tab_id") args = CommandParameters( - actor=g.user, datasource_id=item["datasource_id"], datasource_type=item["datasource_type"], chart_id=item.get("chart_id"), @@ -173,7 +172,6 @@ class ExploreFormDataRestApi(BaseApi): item = self.edit_model_schema.load(request.json) tab_id = request.args.get("tab_id") args = CommandParameters( - actor=g.user, datasource_id=item["datasource_id"], datasource_type=item["datasource_type"], chart_id=item.get("chart_id"), @@ -233,7 +231,7 @@ class ExploreFormDataRestApi(BaseApi): $ref: '#/components/responses/500' """ try: - args = CommandParameters(actor=g.user, key=key) + args = CommandParameters(key=key) form_data = GetFormDataCommand(args).run() if not form_data: return self.response_404() @@ -285,7 +283,7 @@ class ExploreFormDataRestApi(BaseApi): $ref: '#/components/responses/500' """ try: - args = CommandParameters(actor=g.user, key=key) + args = CommandParameters(key=key) result = DeleteFormDataCommand(args).run() if not result: return self.response_404() diff --git a/superset/explore/form_data/commands/create.py b/superset/explore/form_data/commands/create.py index 5c301a96f..df0250f2f 100644 --- a/superset/explore/form_data/commands/create.py +++ b/superset/explore/form_data/commands/create.py @@ -24,10 +24,10 @@ from superset.explore.form_data.commands.parameters import CommandParameters from superset.explore.form_data.commands.state import TemporaryExploreState from superset.explore.form_data.commands.utils import check_access from superset.extensions import cache_manager -from superset.key_value.utils import get_owner, random_key +from superset.key_value.utils import random_key from superset.temporary_cache.commands.exceptions import TemporaryCacheCreateFailedError from superset.temporary_cache.utils import cache_key -from superset.utils.core import DatasourceType +from superset.utils.core import DatasourceType, get_user_id from superset.utils.schema import validate_json logger = logging.getLogger(__name__) @@ -44,9 +44,8 @@ class CreateFormDataCommand(BaseCommand): datasource_type = self._cmd_params.datasource_type chart_id = self._cmd_params.chart_id tab_id = self._cmd_params.tab_id - actor = self._cmd_params.actor form_data = self._cmd_params.form_data - check_access(datasource_id, chart_id, actor, datasource_type) + check_access(datasource_id, chart_id, datasource_type) contextual_key = cache_key( session.get("_id"), tab_id, datasource_id, chart_id, datasource_type ) @@ -55,7 +54,7 @@ class CreateFormDataCommand(BaseCommand): key = random_key() if form_data: state: TemporaryExploreState = { - "owner": get_owner(actor), + "owner": get_user_id(), "datasource_id": datasource_id, "datasource_type": DatasourceType(datasource_type), "chart_id": chart_id, diff --git a/superset/explore/form_data/commands/delete.py b/superset/explore/form_data/commands/delete.py index 598ece3f0..bce13b719 100644 --- a/superset/explore/form_data/commands/delete.py +++ b/superset/explore/form_data/commands/delete.py @@ -26,13 +26,12 @@ from superset.explore.form_data.commands.parameters import CommandParameters from superset.explore.form_data.commands.state import TemporaryExploreState from superset.explore.form_data.commands.utils import check_access from superset.extensions import cache_manager -from superset.key_value.utils import get_owner from superset.temporary_cache.commands.exceptions import ( TemporaryCacheAccessDeniedError, TemporaryCacheDeleteFailedError, ) from superset.temporary_cache.utils import cache_key -from superset.utils.core import DatasourceType +from superset.utils.core import DatasourceType, get_user_id logger = logging.getLogger(__name__) @@ -43,7 +42,6 @@ class DeleteFormDataCommand(BaseCommand, ABC): def run(self) -> bool: try: - actor = self._cmd_params.actor key = self._cmd_params.key state: TemporaryExploreState = cache_manager.explore_form_data_cache.get( key @@ -52,8 +50,8 @@ class DeleteFormDataCommand(BaseCommand, ABC): datasource_id: int = state["datasource_id"] chart_id: Optional[int] = state["chart_id"] datasource_type = DatasourceType(state["datasource_type"]) - check_access(datasource_id, chart_id, actor, datasource_type) - if state["owner"] != get_owner(actor): + check_access(datasource_id, chart_id, datasource_type) + if state["owner"] != get_user_id(): raise TemporaryCacheAccessDeniedError() tab_id = self._cmd_params.tab_id contextual_key = cache_key( diff --git a/superset/explore/form_data/commands/get.py b/superset/explore/form_data/commands/get.py index 982c8e3b4..53fd6ea6a 100644 --- a/superset/explore/form_data/commands/get.py +++ b/superset/explore/form_data/commands/get.py @@ -40,7 +40,6 @@ class GetFormDataCommand(BaseCommand, ABC): def run(self) -> Optional[str]: try: - actor = self._cmd_params.actor key = self._cmd_params.key state: TemporaryExploreState = cache_manager.explore_form_data_cache.get( key @@ -49,7 +48,6 @@ class GetFormDataCommand(BaseCommand, ABC): check_access( state["datasource_id"], state["chart_id"], - actor, DatasourceType(state["datasource_type"]), ) if self._refresh_timeout: diff --git a/superset/explore/form_data/commands/parameters.py b/superset/explore/form_data/commands/parameters.py index fec06a581..c6c1574c8 100644 --- a/superset/explore/form_data/commands/parameters.py +++ b/superset/explore/form_data/commands/parameters.py @@ -17,14 +17,11 @@ from dataclasses import dataclass from typing import Optional -from flask_appbuilder.security.sqla.models import User - from superset.utils.core import DatasourceType @dataclass class CommandParameters: - actor: User datasource_type: DatasourceType = DatasourceType.TABLE datasource_id: int = 0 chart_id: int = 0 diff --git a/superset/explore/form_data/commands/update.py b/superset/explore/form_data/commands/update.py index f48d8e85e..ace57350c 100644 --- a/superset/explore/form_data/commands/update.py +++ b/superset/explore/form_data/commands/update.py @@ -26,13 +26,13 @@ from superset.explore.form_data.commands.parameters import CommandParameters from superset.explore.form_data.commands.state import TemporaryExploreState from superset.explore.form_data.commands.utils import check_access from superset.extensions import cache_manager -from superset.key_value.utils import get_owner, random_key +from superset.key_value.utils import random_key from superset.temporary_cache.commands.exceptions import ( TemporaryCacheAccessDeniedError, TemporaryCacheUpdateFailedError, ) from superset.temporary_cache.utils import cache_key -from superset.utils.core import DatasourceType +from superset.utils.core import DatasourceType, get_user_id from superset.utils.schema import validate_json logger = logging.getLogger(__name__) @@ -51,14 +51,13 @@ class UpdateFormDataCommand(BaseCommand, ABC): datasource_id = self._cmd_params.datasource_id chart_id = self._cmd_params.chart_id datasource_type = self._cmd_params.datasource_type - actor = self._cmd_params.actor key = self._cmd_params.key form_data = self._cmd_params.form_data - check_access(datasource_id, chart_id, actor, datasource_type) + check_access(datasource_id, chart_id, datasource_type) state: TemporaryExploreState = cache_manager.explore_form_data_cache.get( key ) - owner = get_owner(actor) + owner = get_user_id() if state and form_data: if state["owner"] != owner: raise TemporaryCacheAccessDeniedError() diff --git a/superset/explore/form_data/commands/utils.py b/superset/explore/form_data/commands/utils.py index 792745717..e4a843dc6 100644 --- a/superset/explore/form_data/commands/utils.py +++ b/superset/explore/form_data/commands/utils.py @@ -16,8 +16,6 @@ # under the License. from typing import Optional -from flask_appbuilder.security.sqla.models import User - from superset.charts.commands.exceptions import ( ChartAccessDeniedError, ChartNotFoundError, @@ -37,11 +35,10 @@ from superset.utils.core import DatasourceType def check_access( datasource_id: int, chart_id: Optional[int], - actor: User, datasource_type: DatasourceType, ) -> None: try: - explore_check_access(datasource_id, chart_id, actor, datasource_type) + explore_check_access(datasource_id, chart_id, datasource_type) except (ChartNotFoundError, DatasetNotFoundError) as ex: raise TemporaryCacheResourceNotFoundError from ex except (ChartAccessDeniedError, DatasetAccessDeniedError) as ex: diff --git a/superset/explore/permalink/api.py b/superset/explore/permalink/api.py index 1d78e4354..7e2813dec 100644 --- a/superset/explore/permalink/api.py +++ b/superset/explore/permalink/api.py @@ -16,7 +16,7 @@ # under the License. import logging -from flask import g, request, Response +from flask import request, Response from flask_appbuilder.api import BaseApi, expose, protect, safe from marshmallow import ValidationError @@ -100,7 +100,7 @@ class ExplorePermalinkRestApi(BaseApi): """ try: state = self.add_model_schema.load(request.json) - key = CreateExplorePermalinkCommand(actor=g.user, state=state).run() + key = CreateExplorePermalinkCommand(state=state).run() http_origin = request.headers.environ.get("HTTP_ORIGIN") url = f"{http_origin}/superset/explore/p/{key}/" return self.response(201, key=key, url=url) @@ -156,7 +156,7 @@ class ExplorePermalinkRestApi(BaseApi): $ref: '#/components/responses/500' """ try: - value = GetExplorePermalinkCommand(actor=g.user, key=key).run() + value = GetExplorePermalinkCommand(key=key).run() if not value: return self.response_404() return self.response(200, **value) diff --git a/superset/explore/permalink/commands/create.py b/superset/explore/permalink/commands/create.py index 7bd6365d8..77ce04c4e 100644 --- a/superset/explore/permalink/commands/create.py +++ b/superset/explore/permalink/commands/create.py @@ -17,7 +17,6 @@ import logging from typing import Any, Dict, Optional -from flask_appbuilder.security.sqla.models import User from sqlalchemy.exc import SQLAlchemyError from superset.explore.permalink.commands.base import BaseExplorePermalinkCommand @@ -31,8 +30,7 @@ logger = logging.getLogger(__name__) class CreateExplorePermalinkCommand(BaseExplorePermalinkCommand): - def __init__(self, actor: User, state: Dict[str, Any]): - self.actor = actor + def __init__(self, state: Dict[str, Any]): self.chart_id: Optional[int] = state["formData"].get("slice_id") self.datasource: str = state["formData"]["datasource"] self.state = state @@ -43,9 +41,7 @@ class CreateExplorePermalinkCommand(BaseExplorePermalinkCommand): d_id, d_type = self.datasource.split("__") datasource_id = int(d_id) datasource_type = DatasourceType(d_type) - check_chart_access( - datasource_id, self.chart_id, self.actor, datasource_type - ) + check_chart_access(datasource_id, self.chart_id, datasource_type) value = { "chartId": self.chart_id, "datasourceId": datasource_id, @@ -54,7 +50,6 @@ class CreateExplorePermalinkCommand(BaseExplorePermalinkCommand): "state": self.state, } command = CreateKeyValueCommand( - actor=self.actor, resource=self.resource, value=value, ) diff --git a/superset/explore/permalink/commands/get.py b/superset/explore/permalink/commands/get.py index ca4fe8c74..3376cab08 100644 --- a/superset/explore/permalink/commands/get.py +++ b/superset/explore/permalink/commands/get.py @@ -17,7 +17,6 @@ import logging from typing import Optional -from flask_appbuilder.security.sqla.models import User from sqlalchemy.exc import SQLAlchemyError from superset.datasets.commands.exceptions import DatasetNotFoundError @@ -34,8 +33,7 @@ logger = logging.getLogger(__name__) class GetExplorePermalinkCommand(BaseExplorePermalinkCommand): - def __init__(self, actor: User, key: str): - self.actor = actor + def __init__(self, key: str): self.key = key def run(self) -> Optional[ExplorePermalinkValue]: @@ -55,7 +53,7 @@ class GetExplorePermalinkCommand(BaseExplorePermalinkCommand): datasource_type = DatasourceType( value.get("datasourceType", DatasourceType.TABLE) ) - check_chart_access(datasource_id, chart_id, self.actor, datasource_type) + check_chart_access(datasource_id, chart_id, datasource_type) return value return None except ( diff --git a/superset/explore/utils.py b/superset/explore/utils.py index f0bfd8f0a..a1c329510 100644 --- a/superset/explore/utils.py +++ b/superset/explore/utils.py @@ -16,8 +16,6 @@ # under the License. from typing import Optional -from flask_appbuilder.security.sqla.models import User - from superset import security_manager from superset.charts.commands.exceptions import ( ChartAccessDeniedError, @@ -36,8 +34,6 @@ from superset.datasets.commands.exceptions import ( from superset.datasets.dao import DatasetDAO from superset.queries.dao import QueryDAO from superset.utils.core import DatasourceType -from superset.views.base import is_user_admin -from superset.views.utils import is_owner def check_dataset_access(dataset_id: int) -> Optional[bool]: @@ -80,7 +76,6 @@ def check_datasource_access( def check_access( datasource_id: int, chart_id: Optional[int], - actor: User, datasource_type: DatasourceType, ) -> Optional[bool]: check_datasource_access(datasource_id, datasource_type) @@ -88,11 +83,9 @@ def check_access( return True chart = ChartDAO.find_by_id(chart_id) if chart: - can_access_chart = ( - is_user_admin() - or is_owner(chart, actor) - or security_manager.can_access("can_read", "Chart") - ) + can_access_chart = security_manager.is_owner( + chart + ) or security_manager.can_access("can_read", "Chart") if can_access_chart: return True raise ChartAccessDeniedError() diff --git a/superset/key_value/commands/create.py b/superset/key_value/commands/create.py index 5125ce7b0..d4ab4c5c3 100644 --- a/superset/key_value/commands/create.py +++ b/superset/key_value/commands/create.py @@ -20,7 +20,6 @@ from datetime import datetime from typing import Any, Optional, Union from uuid import UUID -from flask_appbuilder.security.sqla.models import User from sqlalchemy.exc import SQLAlchemyError from superset import db @@ -28,23 +27,21 @@ from superset.commands.base import BaseCommand from superset.key_value.exceptions import KeyValueCreateFailedError from superset.key_value.models import KeyValueEntry from superset.key_value.types import Key, KeyValueResource +from superset.utils.core import get_user_id logger = logging.getLogger(__name__) class CreateKeyValueCommand(BaseCommand): - actor: Optional[User] resource: KeyValueResource value: Any key: Optional[Union[int, UUID]] expires_on: Optional[datetime] - # pylint: disable=too-many-arguments def __init__( self, resource: KeyValueResource, value: Any, - actor: Optional[User] = None, key: Optional[Union[int, UUID]] = None, expires_on: Optional[datetime] = None, ): @@ -53,13 +50,11 @@ class CreateKeyValueCommand(BaseCommand): :param resource: the resource (dashboard, chart etc) :param value: the value to persist in the key-value store - :param actor: the user performing the command :param key: id of entry (autogenerated if undefined) :param expires_on: entry expiration time :return: the key associated with the persisted value """ self.resource = resource - self.actor = actor self.value = value self.key = key self.expires_on = expires_on @@ -80,9 +75,7 @@ class CreateKeyValueCommand(BaseCommand): resource=self.resource.value, value=pickle.dumps(self.value), created_on=datetime.now(), - created_by_fk=None - if self.actor is None or self.actor.is_anonymous - else self.actor.id, + created_by_fk=get_user_id(), expires_on=self.expires_on, ) if self.key is not None: diff --git a/superset/key_value/commands/update.py b/superset/key_value/commands/update.py index 48fd8daa8..4078a0dcb 100644 --- a/superset/key_value/commands/update.py +++ b/superset/key_value/commands/update.py @@ -21,7 +21,6 @@ from datetime import datetime from typing import Any, Optional, Union from uuid import UUID -from flask_appbuilder.security.sqla.models import User from sqlalchemy.exc import SQLAlchemyError from superset import db @@ -30,24 +29,22 @@ from superset.key_value.exceptions import KeyValueUpdateFailedError from superset.key_value.models import KeyValueEntry from superset.key_value.types import Key, KeyValueResource from superset.key_value.utils import get_filter +from superset.utils.core import get_user_id logger = logging.getLogger(__name__) class UpdateKeyValueCommand(BaseCommand): - actor: Optional[User] resource: KeyValueResource value: Any key: Union[int, UUID] expires_on: Optional[datetime] - # pylint: disable=too-many-argumentsåå def __init__( self, resource: KeyValueResource, key: Union[int, UUID], value: Any, - actor: Optional[User] = None, expires_on: Optional[datetime] = None, ): """ @@ -56,11 +53,9 @@ class UpdateKeyValueCommand(BaseCommand): :param resource: the resource (dashboard, chart etc) :param key: the key to update :param value: the value to persist in the key-value store - :param actor: the user performing the command :param expires_on: entry expiration time :return: the key associated with the updated value """ - self.actor = actor self.resource = resource self.key = key self.value = value @@ -89,9 +84,7 @@ class UpdateKeyValueCommand(BaseCommand): entry.value = pickle.dumps(self.value) entry.expires_on = self.expires_on entry.changed_on = datetime.now() - entry.changed_by_fk = ( - None if self.actor is None or self.actor.is_anonymous else self.actor.id - ) + entry.changed_by_fk = get_user_id() db.session.merge(entry) db.session.commit() return Key(id=entry.id, uuid=entry.uuid) diff --git a/superset/key_value/commands/upsert.py b/superset/key_value/commands/upsert.py index 8fd0bd240..4bb64aa24 100644 --- a/superset/key_value/commands/upsert.py +++ b/superset/key_value/commands/upsert.py @@ -21,7 +21,6 @@ from datetime import datetime from typing import Any, Optional, Union from uuid import UUID -from flask_appbuilder.security.sqla.models import User from sqlalchemy.exc import SQLAlchemyError from superset import db @@ -31,24 +30,22 @@ from superset.key_value.exceptions import KeyValueUpdateFailedError from superset.key_value.models import KeyValueEntry from superset.key_value.types import Key, KeyValueResource from superset.key_value.utils import get_filter +from superset.utils.core import get_user_id logger = logging.getLogger(__name__) class UpsertKeyValueCommand(BaseCommand): - actor: Optional[User] resource: KeyValueResource value: Any key: Union[int, UUID] expires_on: Optional[datetime] - # pylint: disable=too-many-arguments def __init__( self, resource: KeyValueResource, key: Union[int, UUID], value: Any, - actor: Optional[User] = None, expires_on: Optional[datetime] = None, ): """ @@ -58,11 +55,9 @@ class UpsertKeyValueCommand(BaseCommand): :param key: the key to update :param value: the value to persist in the key-value store :param key_type: the type of the key to update - :param actor: the user performing the command :param expires_on: entry expiration time :return: the key associated with the updated value """ - self.actor = actor self.resource = resource self.key = key self.value = value @@ -91,16 +86,13 @@ class UpsertKeyValueCommand(BaseCommand): entry.value = pickle.dumps(self.value) entry.expires_on = self.expires_on entry.changed_on = datetime.now() - entry.changed_by_fk = ( - None if self.actor is None or self.actor.is_anonymous else self.actor.id - ) + entry.changed_by_fk = get_user_id() db.session.merge(entry) db.session.commit() return Key(entry.id, entry.uuid) return CreateKeyValueCommand( resource=self.resource, value=self.value, - actor=self.actor, key=self.key, expires_on=self.expires_on, ).run() diff --git a/superset/key_value/utils.py b/superset/key_value/utils.py index ec0e06216..b2e8e729b 100644 --- a/superset/key_value/utils.py +++ b/superset/key_value/utils.py @@ -18,11 +18,10 @@ from __future__ import annotations from hashlib import md5 from secrets import token_urlsafe -from typing import Optional, Union +from typing import Union from uuid import UUID import hashids -from flask_appbuilder.security.sqla.models import User from flask_babel import gettext as _ from superset.key_value.exceptions import KeyValueParseKeyError @@ -64,7 +63,3 @@ def get_uuid_namespace(seed: str) -> UUID: md5_obj = md5() md5_obj.update(seed.encode("utf-8")) return UUID(md5_obj.hexdigest()) - - -def get_owner(user: User) -> Optional[int]: - return user.id if not user.is_anonymous else None diff --git a/superset/models/dashboard.py b/superset/models/dashboard.py index 12f705616..6c733ff4e 100644 --- a/superset/models/dashboard.py +++ b/superset/models/dashboard.py @@ -23,7 +23,6 @@ from functools import partial from typing import Any, Callable, Dict, List, Set, Tuple, Type, Union import sqlalchemy as sqla -from flask import g from flask_appbuilder import Model from flask_appbuilder.models.decorators import renders from flask_appbuilder.security.sqla.models import User @@ -47,7 +46,6 @@ from sqlalchemy.sql import join, select from sqlalchemy.sql.elements import BinaryExpression from superset import app, db, is_feature_enabled, security_manager -from superset.common.request_contexed_based import is_user_admin from superset.connectors.base.models import BaseDatasource from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn from superset.datasource.dao import DatasourceDAO @@ -59,6 +57,7 @@ from superset.models.tags import DashboardUpdater from superset.models.user_attributes import UserAttribute from superset.tasks.thumbnails import cache_dashboard_thumbnail from superset.utils import core as utils +from superset.utils.core import get_user_id from superset.utils.decorators import debounce from superset.utils.hashing import md5_sha_from_str from superset.utils.urls import get_url_path @@ -203,15 +202,14 @@ class Dashboard(Model, AuditMixinNullable, ImportExportMixin): @property def filter_sets_lst(self) -> Dict[int, FilterSet]: - if is_user_admin(): + if security_manager.is_admin(): return self._filter_sets - current_user = g.user.id filter_sets_by_owner_type: Dict[str, List[Any]] = {"Dashboard": [], "User": []} for fs in self._filter_sets: filter_sets_by_owner_type[fs.owner_type].append(fs) user_filter_sets = list( filter( - lambda filter_set: filter_set.owner_id == current_user, + lambda filter_set: filter_set.owner_id == get_user_id(), filter_sets_by_owner_type["User"], ) ) @@ -445,11 +443,6 @@ class Dashboard(Model, AuditMixinNullable, ImportExportMixin): qry = session.query(Dashboard).filter(id_or_slug_filter(id_or_slug)) return qry.one_or_none() - def is_actor_owner(self) -> bool: - if g.user is None or g.user.is_anonymous or not g.user.is_authenticated: - return False - return g.user.id in set(map(lambda user: user.id, self.owners)) - def id_or_slug_filter(id_or_slug: str) -> BinaryExpression: if id_or_slug.isdigit(): diff --git a/superset/queries/saved_queries/api.py b/superset/queries/saved_queries/api.py index 04df2345c..a82a3dd8e 100644 --- a/superset/queries/saved_queries/api.py +++ b/superset/queries/saved_queries/api.py @@ -192,7 +192,7 @@ class SavedQueryRestApi(BaseSupersetModelRestApi): """ item_ids = kwargs["rison"] try: - BulkDeleteSavedQueryCommand(g.user, item_ids).run() + BulkDeleteSavedQueryCommand(item_ids).run() return self.response( 200, message=ngettext( diff --git a/superset/queries/saved_queries/commands/bulk_delete.py b/superset/queries/saved_queries/commands/bulk_delete.py index 0d199378c..c96afd31e 100644 --- a/superset/queries/saved_queries/commands/bulk_delete.py +++ b/superset/queries/saved_queries/commands/bulk_delete.py @@ -17,8 +17,6 @@ import logging from typing import List, Optional -from flask_appbuilder.security.sqla.models import User - from superset.commands.base import BaseCommand from superset.dao.exceptions import DAODeleteFailedError from superset.models.dashboard import Dashboard @@ -32,8 +30,7 @@ logger = logging.getLogger(__name__) class BulkDeleteSavedQueryCommand(BaseCommand): - def __init__(self, user: User, model_ids: List[int]): - self._actor = user + def __init__(self, model_ids: List[int]): self._model_ids = model_ids self._models: Optional[List[Dashboard]] = None diff --git a/superset/reports/api.py b/superset/reports/api.py index 046fe9059..9f6fb86a7 100644 --- a/superset/reports/api.py +++ b/superset/reports/api.py @@ -17,7 +17,7 @@ import logging from typing import Any, Optional -from flask import g, request, Response +from flask import request, Response from flask_appbuilder.api import expose, permission_name, protect, rison, safe from flask_appbuilder.hooks import before_request from flask_appbuilder.models.sqla.interface import SQLAInterface @@ -266,7 +266,7 @@ class ReportScheduleRestApi(BaseSupersetModelRestApi): $ref: '#/components/responses/500' """ try: - DeleteReportScheduleCommand(g.user, pk).run() + DeleteReportScheduleCommand(pk).run() return self.response(200, message="OK") except ReportScheduleNotFoundError: return self.response_404() @@ -340,7 +340,7 @@ class ReportScheduleRestApi(BaseSupersetModelRestApi): except ValidationError as error: return self.response_400(message=error.messages) try: - new_model = CreateReportScheduleCommand(g.user, item).run() + new_model = CreateReportScheduleCommand(item).run() return self.response(201, id=new_model.id, result=item) except ReportScheduleNotFoundError as ex: return self.response_400(message=str(ex)) @@ -421,7 +421,7 @@ class ReportScheduleRestApi(BaseSupersetModelRestApi): except ValidationError as error: return self.response_400(message=error.messages) try: - new_model = UpdateReportScheduleCommand(g.user, pk, item).run() + new_model = UpdateReportScheduleCommand(pk, item).run() return self.response(200, id=new_model.id, result=item) except ReportScheduleNotFoundError: return self.response_404() @@ -483,7 +483,7 @@ class ReportScheduleRestApi(BaseSupersetModelRestApi): """ item_ids = kwargs["rison"] try: - BulkDeleteReportScheduleCommand(g.user, item_ids).run() + BulkDeleteReportScheduleCommand(item_ids).run() return self.response( 200, message=ngettext( diff --git a/superset/reports/commands/bulk_delete.py b/superset/reports/commands/bulk_delete.py index 4bff600d2..131a97af2 100644 --- a/superset/reports/commands/bulk_delete.py +++ b/superset/reports/commands/bulk_delete.py @@ -17,8 +17,7 @@ import logging from typing import List, Optional -from flask_appbuilder.security.sqla.models import User - +from superset import security_manager from superset.commands.base import BaseCommand from superset.dao.exceptions import DAODeleteFailedError from superset.exceptions import SupersetSecurityException @@ -29,14 +28,12 @@ from superset.reports.commands.exceptions import ( ReportScheduleNotFoundError, ) from superset.reports.dao import ReportScheduleDAO -from superset.views.base import check_ownership logger = logging.getLogger(__name__) class BulkDeleteReportScheduleCommand(BaseCommand): - def __init__(self, user: User, model_ids: List[int]): - self._actor = user + def __init__(self, model_ids: List[int]): self._model_ids = model_ids self._models: Optional[List[ReportSchedule]] = None @@ -58,6 +55,6 @@ class BulkDeleteReportScheduleCommand(BaseCommand): # Check ownership for model in self._models: try: - check_ownership(model) + security_manager.raise_for_ownership(model) except SupersetSecurityException as ex: raise ReportScheduleForbiddenError() from ex diff --git a/superset/reports/commands/create.py b/superset/reports/commands/create.py index 6d9161445..a67aabef9 100644 --- a/superset/reports/commands/create.py +++ b/superset/reports/commands/create.py @@ -19,7 +19,6 @@ import logging from typing import Any, Dict, List, Optional from flask_appbuilder.models.sqla import Model -from flask_appbuilder.security.sqla.models import User from marshmallow import ValidationError from superset.commands.base import CreateMixin @@ -42,8 +41,7 @@ logger = logging.getLogger(__name__) class CreateReportScheduleCommand(CreateMixin, BaseReportScheduleCommand): - def __init__(self, user: User, data: Dict[str, Any]): - self._actor = user + def __init__(self, data: Dict[str, Any]): self._properties = data.copy() def run(self) -> Model: @@ -63,7 +61,6 @@ class CreateReportScheduleCommand(CreateMixin, BaseReportScheduleCommand): creation_method = self._properties.get("creation_method") chart_id = self._properties.get("chart") dashboard_id = self._properties.get("dashboard") - user_id = self._actor.id # Validate type is required if not report_type: @@ -99,7 +96,7 @@ class CreateReportScheduleCommand(CreateMixin, BaseReportScheduleCommand): if ( creation_method != ReportCreationMethod.ALERTS_REPORTS and not ReportScheduleDAO.validate_unique_creation_method( - user_id, dashboard_id, chart_id + dashboard_id, chart_id ) ): raise ReportScheduleCreationMethodUniquenessValidationError() @@ -110,7 +107,7 @@ class CreateReportScheduleCommand(CreateMixin, BaseReportScheduleCommand): ) try: - owners = self.populate_owners(self._actor, owner_ids) + owners = self.populate_owners(owner_ids) self._properties["owners"] = owners except ValidationError as ex: exceptions.append(ex) diff --git a/superset/reports/commands/delete.py b/superset/reports/commands/delete.py index eef7a56af..4c38f9b7c 100644 --- a/superset/reports/commands/delete.py +++ b/superset/reports/commands/delete.py @@ -18,8 +18,8 @@ import logging from typing import Optional from flask_appbuilder.models.sqla import Model -from flask_appbuilder.security.sqla.models import User +from superset import security_manager from superset.commands.base import BaseCommand from superset.dao.exceptions import DAODeleteFailedError from superset.exceptions import SupersetSecurityException @@ -30,14 +30,12 @@ from superset.reports.commands.exceptions import ( ReportScheduleNotFoundError, ) from superset.reports.dao import ReportScheduleDAO -from superset.views.base import check_ownership logger = logging.getLogger(__name__) class DeleteReportScheduleCommand(BaseCommand): - def __init__(self, user: User, model_id: int): - self._actor = user + def __init__(self, model_id: int): self._model_id = model_id self._model: Optional[ReportSchedule] = None @@ -58,6 +56,6 @@ class DeleteReportScheduleCommand(BaseCommand): # Check ownership try: - check_ownership(self._model) + security_manager.raise_for_ownership(self._model) except SupersetSecurityException as ex: raise ReportScheduleForbiddenError() from ex diff --git a/superset/reports/commands/update.py b/superset/reports/commands/update.py index 201d96186..c43ee47a0 100644 --- a/superset/reports/commands/update.py +++ b/superset/reports/commands/update.py @@ -19,9 +19,9 @@ import logging from typing import Any, Dict, List, Optional from flask_appbuilder.models.sqla import Model -from flask_appbuilder.security.sqla.models import User from marshmallow import ValidationError +from superset import security_manager from superset.commands.base import UpdateMixin from superset.dao.exceptions import DAOUpdateFailedError from superset.databases.dao import DatabaseDAO @@ -37,14 +37,12 @@ from superset.reports.commands.exceptions import ( ReportScheduleUpdateFailedError, ) from superset.reports.dao import ReportScheduleDAO -from superset.views.base import check_ownership logger = logging.getLogger(__name__) class UpdateReportScheduleCommand(UpdateMixin, BaseReportScheduleCommand): - def __init__(self, user: User, model_id: int, data: Dict[str, Any]): - self._actor = user + def __init__(self, model_id: int, data: Dict[str, Any]): self._model_id = model_id self._properties = data.copy() self._model: Optional[ReportSchedule] = None @@ -113,7 +111,7 @@ class UpdateReportScheduleCommand(UpdateMixin, BaseReportScheduleCommand): # Check ownership try: - check_ownership(self._model) + security_manager.raise_for_ownership(self._model) except SupersetSecurityException as ex: raise ReportScheduleForbiddenError() from ex @@ -121,7 +119,7 @@ class UpdateReportScheduleCommand(UpdateMixin, BaseReportScheduleCommand): if owner_ids is None: owner_ids = [owner.id for owner in self._model.owners] try: - owners = self.populate_owners(self._actor, owner_ids) + owners = self.populate_owners(owner_ids) self._properties["owners"] = owners except ValidationError as ex: exceptions.append(ex) diff --git a/superset/reports/dao.py b/superset/reports/dao.py index 312710fe0..21b1473f3 100644 --- a/superset/reports/dao.py +++ b/superset/reports/dao.py @@ -33,6 +33,7 @@ from superset.models.reports import ( ReportScheduleType, ReportState, ) +from superset.utils.core import get_user_id logger = logging.getLogger(__name__) @@ -116,14 +117,14 @@ class ReportScheduleDAO(BaseDAO): @staticmethod def validate_unique_creation_method( - user_id: int, dashboard_id: Optional[int] = None, chart_id: Optional[int] = None + dashboard_id: Optional[int] = None, chart_id: Optional[int] = None ) -> bool: """ Validate if the user already has a chart or dashboard with a report attached form the self subscribe reports """ - query = db.session.query(ReportSchedule).filter_by(created_by_fk=user_id) + query = db.session.query(ReportSchedule).filter_by(created_by_fk=get_user_id()) if dashboard_id is not None: query = query.filter(ReportSchedule.dashboard_id == dashboard_id) diff --git a/superset/security/manager.py b/superset/security/manager.py index 78c83f72e..c1905c63f 100644 --- a/superset/security/manager.py +++ b/superset/security/manager.py @@ -1093,7 +1093,6 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods from superset.connectors.sqla.models import SqlaTable from superset.extensions import feature_flag_manager from superset.sql_parse import Table - from superset.views.utils import is_owner if database and table or query: if query: @@ -1126,7 +1125,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods for datasource_ in datasources: if self.can_access( "datasource_access", datasource_.perm - ) or is_owner(datasource_, getattr(g, "user", None)): + ) or self.is_owner(datasource_): break else: denied.add(table_) @@ -1152,7 +1151,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods if not ( self.can_access_schema(datasource) or self.can_access("datasource_access", datasource.perm or "") - or is_owner(datasource, getattr(g, "user", None)) + or self.is_owner(datasource) or ( should_check_dashboard_access and self.can_access_based_on_dashboard(datasource) @@ -1327,8 +1326,6 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods # pylint: disable=import-outside-toplevel from superset import is_feature_enabled from superset.dashboards.commands.exceptions import DashboardAccessDeniedError - from superset.views.base import is_user_admin - from superset.views.utils import is_owner def has_rbac_access() -> bool: return (not is_feature_enabled("DASHBOARD_RBAC")) or any( @@ -1341,8 +1338,8 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods can_access = self.has_guest_access(dashboard) else: can_access = ( - is_user_admin() - or is_owner(dashboard, g.user) + self.is_admin() + or self.is_owner(dashboard) or (dashboard.published and has_rbac_access()) or (not dashboard.published and not dashboard.roles) ) @@ -1520,3 +1517,69 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods if str(resource["id"]) == str(dashboard.embedded[0].uuid): return True return False + + def raise_for_ownership(self, resource: Model) -> None: + """ + Raise an exception if the user does not own the resource. + + Note admins are deemed owners of all resources. + + :param resource: The dashboard, dataste, chart, etc. resource + :raises SupersetSecurityException: If the current user is not an owner + """ + + # pylint: disable=import-outside-toplevel + from superset import db + + if self.is_admin(): + return + + # Set of wners that works across ORM models. + owners: List[User] = [] + + orig_resource = db.session.query(resource.__class__).get(resource.id) + + if orig_resource: + if hasattr(resource, "owners"): + owners += orig_resource.owners + + if hasattr(resource, "owner"): + owners.append(orig_resource.owner) + + if hasattr(resource, "created_by"): + owners.append(orig_resource.created_by) + + if g.user.is_anonymous or g.user not in owners: + raise SupersetSecurityException( + SupersetError( + error_type=SupersetErrorType.MISSING_OWNERSHIP_ERROR, + message=f"You don't have the rights to alter [{resource}]", + level=ErrorLevel.ERROR, + ) + ) + + def is_owner(self, resource: Model) -> bool: + """ + Returns True if the current user is an owner of the resource, False otherwise. + + :param resource: The dashboard, dataste, chart, etc. resource + :returns: Whethe the current user is an owner of the resource + """ + + try: + self.raise_for_ownership(resource) + except SupersetSecurityException: + return False + + return True + + def is_admin(self) -> bool: + """ + Returns True if the current user is an admin user, False otherwise. + + :returns: Whehther the current user is an admin user + """ + + return current_app.config["AUTH_ROLE_ADMIN"] in [ + role.name for role in self.get_user_roles() + ] diff --git a/superset/temporary_cache/api.py b/superset/temporary_cache/api.py index bdbdda302..a2d2e287d 100644 --- a/superset/temporary_cache/api.py +++ b/superset/temporary_cache/api.py @@ -20,7 +20,7 @@ from typing import Any from apispec import APISpec from apispec.exceptions import DuplicateComponentNameError -from flask import g, request, Response +from flask import request, Response from flask_appbuilder.api import BaseApi from marshmallow import ValidationError @@ -70,9 +70,7 @@ class TemporaryCacheRestApi(BaseApi, ABC): try: item = self.add_model_schema.load(request.json) tab_id = request.args.get("tab_id") - args = CommandParameters( - actor=g.user, resource_id=pk, value=item["value"], tab_id=tab_id - ) + args = CommandParameters(resource_id=pk, value=item["value"], tab_id=tab_id) key = self.get_create_command()(args).run() return self.response(201, key=key) except ValidationError as ex: @@ -88,7 +86,6 @@ class TemporaryCacheRestApi(BaseApi, ABC): item = self.edit_model_schema.load(request.json) tab_id = request.args.get("tab_id") args = CommandParameters( - actor=g.user, resource_id=pk, key=key, value=item["value"], @@ -105,7 +102,7 @@ class TemporaryCacheRestApi(BaseApi, ABC): def get(self, pk: int, key: str) -> Response: try: - args = CommandParameters(actor=g.user, resource_id=pk, key=key) + args = CommandParameters(resource_id=pk, key=key) value = self.get_get_command()(args).run() if not value: return self.response_404() @@ -117,7 +114,7 @@ class TemporaryCacheRestApi(BaseApi, ABC): def delete(self, pk: int, key: str) -> Response: try: - args = CommandParameters(actor=g.user, resource_id=pk, key=key) + args = CommandParameters(resource_id=pk, key=key) result = self.get_delete_command()(args).run() if not result: return self.response_404() diff --git a/superset/temporary_cache/commands/parameters.py b/superset/temporary_cache/commands/parameters.py index 4d98167c3..74b9c1c63 100644 --- a/superset/temporary_cache/commands/parameters.py +++ b/superset/temporary_cache/commands/parameters.py @@ -17,12 +17,9 @@ from dataclasses import dataclass from typing import Optional -from flask_appbuilder.security.sqla.models import User - @dataclass class CommandParameters: - actor: User resource_id: int tab_id: Optional[int] = None key: Optional[str] = None diff --git a/superset/views/access_requests.py b/superset/views/access_requests.py index e60662e36..063ef5e0b 100644 --- a/superset/views/access_requests.py +++ b/superset/views/access_requests.py @@ -25,9 +25,10 @@ from superset.views.base import DeleteMixin, SupersetModelView from superset.views.core import DAR -class AccessRequestsModelView( - SupersetModelView, DeleteMixin -): # pylint: disable=too-many-ancestors +class AccessRequestsModelView( # pylint: disable=too-many-ancestors + SupersetModelView, + DeleteMixin, +): datamodel = SQLAInterface(DAR) include_route_methods = RouteMethod.CRUD_SET list_columns = [ diff --git a/superset/views/annotations.py b/superset/views/annotations.py index 69718f5f5..b9ef65be0 100644 --- a/superset/views/annotations.py +++ b/superset/views/annotations.py @@ -47,9 +47,10 @@ class StartEndDttmValidator: # pylint: disable=too-few-public-methods ) -class AnnotationModelView( - SupersetModelView, CompactCRUDMixin -): # pylint: disable=too-many-ancestors +class AnnotationModelView( # pylint: disable=too-many-ancestors + SupersetModelView, + CompactCRUDMixin, +): datamodel = SQLAInterface(Annotation) include_route_methods = RouteMethod.CRUD_SET | {"annotation"} diff --git a/superset/views/base.py b/superset/views/base.py index 3a0e429cc..26e22e698 100644 --- a/superset/views/base.py +++ b/superset/views/base.py @@ -38,7 +38,6 @@ from flask_appbuilder import BaseView, Model, ModelView from flask_appbuilder.actions import action from flask_appbuilder.forms import DynamicForm from flask_appbuilder.models.sqla.filters import BaseFilter -from flask_appbuilder.security.sqla.models import User from flask_appbuilder.widgets import ListWidget from flask_babel import get_locale, gettext as __, lazy_gettext as _ from flask_jwt_extended.exceptions import NoAuthorizationError @@ -270,11 +269,6 @@ def create_table_permissions(table: models.SqlaTable) -> None: security_manager.add_permission_view_menu("schema_access", table.schema_perm) -def is_user_admin() -> bool: - user_roles = [role.name.lower() for role in list(security_manager.get_user_roles())] - return "admin" in user_roles - - class BaseSupersetView(BaseView): @staticmethod def json_response(obj: Any, status: int = 200) -> FlaskResponse: @@ -644,53 +638,6 @@ class CsvResponse(Response): default_mimetype = "text/csv" -def check_ownership(obj: Any, raise_if_false: bool = True) -> bool: - """Meant to be used in `pre_update` hooks on models to enforce ownership - - Admin have all access, and other users need to be referenced on either - the created_by field that comes with the ``AuditMixin``, or in a field - named ``owners`` which is expected to be a one-to-many with the User - model. It is meant to be used in the ModelView's pre_update hook in - which raising will abort the update. - """ - if not obj: - return False - - security_exception = SupersetSecurityException( - SupersetError( - error_type=SupersetErrorType.MISSING_OWNERSHIP_ERROR, - message="You don't have the rights to alter [{}]".format(obj), - level=ErrorLevel.ERROR, - ) - ) - - if g.user.is_anonymous: - if raise_if_false: - raise security_exception - return False - if is_user_admin(): - return True - scoped_session = db.create_scoped_session() - orig_obj = scoped_session.query(obj.__class__).filter_by(id=obj.id).first() - - # Making a list of owners that works across ORM models - owners: List[User] = [] - if hasattr(orig_obj, "owners"): - owners += orig_obj.owners - if hasattr(orig_obj, "owner"): - owners += [orig_obj.owner] - if hasattr(orig_obj, "created_by"): - owners += [orig_obj.created_by] - - owner_names = [o.username for o in owners if o] - - if g.user and hasattr(g.user, "username") and g.user.username in owner_names: - return True - if raise_if_false: - raise security_exception - return False - - def bind_field( _: Any, form: DynamicForm, unbound_field: UnboundField, options: Dict[Any, Any] ) -> Field: diff --git a/superset/views/chart/views.py b/superset/views/chart/views.py index 60def4868..4d43e797d 100644 --- a/superset/views/chart/views.py +++ b/superset/views/chart/views.py @@ -21,16 +21,12 @@ from flask_appbuilder import expose, has_access from flask_appbuilder.models.sqla.interface import SQLAInterface from flask_babel import lazy_gettext as _ +from superset import security_manager from superset.constants import MODEL_VIEW_RW_METHOD_PERMISSION_MAP, RouteMethod from superset.models.slice import Slice from superset.superset_typing import FlaskResponse from superset.utils import core as utils -from superset.views.base import ( - check_ownership, - common_bootstrap_payload, - DeleteMixin, - SupersetModelView, -) +from superset.views.base import common_bootstrap_payload, DeleteMixin, SupersetModelView from superset.views.chart.mixin import SliceMixin from superset.views.utils import bootstrap_user_data @@ -53,10 +49,10 @@ class SliceModelView( def pre_update(self, item: "SliceModelView") -> None: utils.validate_json(item.params) - check_ownership(item) + security_manager.raise_for_ownership(item) def pre_delete(self, item: "SliceModelView") -> None: - check_ownership(item) + security_manager.raise_for_ownership(item) @expose("/add", methods=["GET", "POST"]) @has_access diff --git a/superset/views/core.py b/superset/views/core.py index 79517f347..298535011 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -140,7 +140,6 @@ from superset.utils.decorators import check_dashboard_access from superset.views.base import ( api, BaseSupersetView, - check_ownership, common_bootstrap_payload, create_table_permissions, CsvResponse, @@ -164,7 +163,6 @@ from superset.views.utils import ( get_datasource_info, get_form_data, get_viz, - is_owner, sanitize_datasource_data, ) from superset.viz import BaseViz @@ -368,8 +366,8 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods return json_error_response(err) # check if you can approve - if security_manager.can_access_all_datasources() or check_ownership( - datasource, raise_if_false=False + if security_manager.can_access_all_datasources() or security_manager.is_owner( + datasource ): # can by done by admin only if role_to_grant: @@ -758,7 +756,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods form_data_key = request.args.get("form_data_key") if key is not None: - command = GetExplorePermalinkCommand(g.user, key) + command = GetExplorePermalinkCommand(key) try: permalink_value = command.run() if permalink_value: @@ -775,7 +773,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods flash(__("Error: %(msg)s", msg=ex.message), "danger") return redirect("/chart/list/") elif form_data_key: - parameters = CommandParameters(actor=g.user, key=form_data_key) + parameters = CommandParameters(key=form_data_key) value = GetFormDataCommand(parameters).run() initial_form_data = json.loads(value) if value else {} @@ -857,7 +855,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods # slc perms slice_add_perm = security_manager.can_access("can_write", "Chart") - slice_overwrite_perm = is_owner(slc, g.user) if slc else False + slice_overwrite_perm = security_manager.is_owner(slc) if slc else False slice_download_perm = security_manager.can_access("can_csv", "Superset") form_data["datasource"] = str(datasource_id) + "__" + cast(str, datasource_type) @@ -1050,7 +1048,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods .one(), ) # check edit dashboard permissions - dash_overwrite_perm = check_ownership(dash, raise_if_false=False) + dash_overwrite_perm = security_manager.is_owner(dash) if not dash_overwrite_perm: return json_error_response( _("You don't have the rights to ") @@ -1297,7 +1295,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods """Save a dashboard's metadata""" session = db.session() dash = session.query(Dashboard).get(dashboard_id) - check_ownership(dash, raise_if_false=True) + security_manager.raise_for_ownership(dash) data = json.loads(request.form["data"]) # client-side send back last_modified_time which was set when # the dashboard was open. it was use to avoid mid-air collision. @@ -1340,7 +1338,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods data = json.loads(request.form["data"]) session = db.session() dash = session.query(Dashboard).get(dashboard_id) - check_ownership(dash, raise_if_false=True) + security_manager.raise_for_ownership(dash) new_slices = session.query(Slice).filter(Slice.id.in_(data["slice_ids"])) dash.slices += new_slices session.merge(dash) @@ -1664,7 +1662,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods def user_slices(self, user_id: Optional[int] = None) -> FlaskResponse: """List of slices a user owns, created, modified or faved""" if not user_id: - user_id = cast(int, g.user.id) + user_id = cast(int, get_user_id()) error_obj = self.get_user_activity_access_error(user_id) if error_obj: return error_obj @@ -1717,7 +1715,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods def created_slices(self, user_id: Optional[int] = None) -> FlaskResponse: """List of slices created by this user""" if not user_id: - user_id = cast(int, g.user.id) + user_id = cast(int, get_user_id()) error_obj = self.get_user_activity_access_error(user_id) if error_obj: return error_obj @@ -1748,7 +1746,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods def fave_slices(self, user_id: Optional[int] = None) -> FlaskResponse: """Favorite slices for a user""" if user_id is None: - user_id = g.user.id + user_id = cast(int, get_user_id()) error_obj = self.get_user_activity_access_error(user_id) if error_obj: return error_obj @@ -1957,8 +1955,8 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods f"/superset/request_access/?dashboard_id={dashboard.id}" ) - dash_edit_perm = check_ownership( - dashboard, raise_if_false=False + dash_edit_perm = security_manager.is_owner( + dashboard ) and security_manager.can_access("can_save_dash", "Superset") edit_mode = ( request.args.get(utils.ReservedUrlParameters.EDIT_MODE.value) == "true" @@ -1994,7 +1992,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods key: str, ) -> FlaskResponse: try: - value = GetDashboardPermalinkCommand(g.user, key).run() + value = GetDashboardPermalinkCommand(key).run() except DashboardPermalinkGetFailedError as ex: flash(__("Error: %(msg)s", msg=ex.message), "danger") return redirect("/dashboard/list/") diff --git a/superset/views/css_templates.py b/superset/views/css_templates.py index 597f9efbd..2041eaa94 100644 --- a/superset/views/css_templates.py +++ b/superset/views/css_templates.py @@ -25,9 +25,10 @@ from superset.superset_typing import FlaskResponse from superset.views.base import DeleteMixin, SupersetModelView -class CssTemplateModelView( - SupersetModelView, DeleteMixin -): # pylint: disable=too-many-ancestors +class CssTemplateModelView( # pylint: disable=too-many-ancestors + SupersetModelView, + DeleteMixin, +): datamodel = SQLAInterface(models.CssTemplate) include_route_methods = RouteMethod.CRUD_SET diff --git a/superset/views/dashboard/mixin.py b/superset/views/dashboard/mixin.py index 77748fdc3..43e4df1c0 100644 --- a/superset/views/dashboard/mixin.py +++ b/superset/views/dashboard/mixin.py @@ -16,8 +16,8 @@ # under the License. from flask_babel import lazy_gettext as _ -from ...dashboards.filters import DashboardAccessFilter -from ..base import check_ownership +from superset import security_manager +from superset.dashboards.filters import DashboardAccessFilter class DashboardMixin: # pylint: disable=too-few-public-methods @@ -90,4 +90,4 @@ class DashboardMixin: # pylint: disable=too-few-public-methods } def pre_delete(self, item: "DashboardMixin") -> None: # pylint: disable=no-self-use - check_ownership(item) + security_manager.raise_for_ownership(item) diff --git a/superset/views/dashboard/views.py b/superset/views/dashboard/views.py index 256bb4c95..8d562fefb 100644 --- a/superset/views/dashboard/views.py +++ b/superset/views/dashboard/views.py @@ -33,7 +33,6 @@ from superset.superset_typing import FlaskResponse from superset.utils import core as utils from superset.views.base import ( BaseSupersetView, - check_ownership, common_bootstrap_payload, DeleteMixin, generate_download_headers, @@ -97,12 +96,11 @@ class DashboardModelView( item.owners.append(g.user) utils.validate_json(item.json_metadata) utils.validate_json(item.position_json) - owners = list(item.owners) for slc in item.slices: - slc.owners = list(set(owners) | set(slc.owners)) + slc.owners = list(set(item.owners) | set(slc.owners)) def pre_update(self, item: "DashboardModelView") -> None: - check_ownership(item) + security_manager.raise_for_ownership(item) self.pre_add(item) diff --git a/superset/views/datasource/views.py b/superset/views/datasource/views.py index bf67eddd0..4e43068c6 100644 --- a/superset/views/datasource/views.py +++ b/superset/views/datasource/views.py @@ -18,7 +18,7 @@ import json from collections import Counter from typing import Any -from flask import g, request +from flask import request from flask_appbuilder import expose from flask_appbuilder.api import rison from flask_appbuilder.security.decorators import has_access_api @@ -27,7 +27,7 @@ from marshmallow import ValidationError from sqlalchemy.exc import NoSuchTableError from sqlalchemy.orm.exc import NoResultFound -from superset import db, event_logger +from superset import db, event_logger, security_manager from superset.commands.utils import populate_owners from superset.connectors.sqla.models import SqlaTable from superset.connectors.sqla.utils import get_physical_table_metadata @@ -37,14 +37,12 @@ from superset.datasets.commands.exceptions import ( ) from superset.datasource.dao import DatasourceDAO from superset.exceptions import SupersetException, SupersetSecurityException -from superset.extensions import security_manager from superset.models.core import Database from superset.superset_typing import FlaskResponse from superset.utils.core import DatasourceType from superset.views.base import ( api, BaseSupersetView, - check_ownership, handle_api_exception, json_error_response, ) @@ -84,13 +82,12 @@ class Datasource(BaseSupersetView): if "owners" in datasource_dict and orm_datasource.owner_class is not None: # Check ownership try: - check_ownership(orm_datasource) + security_manager.raise_for_ownership(orm_datasource) except SupersetSecurityException as ex: raise DatasetForbiddenError() from ex - user = security_manager.get_user_by_id(g.user.id) datasource_dict["owners"] = populate_owners( - user, datasource_dict["owners"], default_to_user=False + datasource_dict["owners"], default_to_user=False ) duplicates = [ diff --git a/superset/views/log/views.py b/superset/views/log/views.py index 6cc8d2ffd..89623d8ec 100644 --- a/superset/views/log/views.py +++ b/superset/views/log/views.py @@ -26,7 +26,10 @@ from superset.views.base import SupersetModelView from . import LogMixin -class LogModelView(LogMixin, SupersetModelView): # pylint: disable=too-many-ancestors +class LogModelView( # pylint: disable=too-many-ancestors + LogMixin, + SupersetModelView, +): datamodel = SQLAInterface(models.Log) include_route_methods = {RouteMethod.LIST, RouteMethod.SHOW} class_permission_name = "Log" diff --git a/superset/views/sql_lab.py b/superset/views/sql_lab.py index ae8b0aedc..f83c4521e 100644 --- a/superset/views/sql_lab.py +++ b/superset/views/sql_lab.py @@ -36,9 +36,10 @@ from .base import BaseSupersetView, DeleteMixin, json_success, SupersetModelView logger = logging.getLogger(__name__) -class SavedQueryView( - SupersetModelView, DeleteMixin -): # pylint: disable=too-many-ancestors +class SavedQueryView( # pylint: disable=too-many-ancestors + SupersetModelView, + DeleteMixin, +): datamodel = SQLAInterface(SavedQuery) include_route_methods = RouteMethod.CRUD_SET diff --git a/superset/views/utils.py b/superset/views/utils.py index d696b4b74..6b6d5a0fb 100644 --- a/superset/views/utils.py +++ b/superset/views/utils.py @@ -32,7 +32,6 @@ from sqlalchemy.orm.exc import NoResultFound import superset.models.core as models from superset import app, dataframe, db, result_set, viz from superset.common.db_query_status import QueryStatus -from superset.connectors.sqla.models import SqlaTable from superset.datasource.dao import DatasourceDAO from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import ( @@ -427,11 +426,6 @@ def is_slice_in_container( return False -def is_owner(obj: Union[Dashboard, Slice, SqlaTable], user: User) -> bool: - """Check if user is owner of the slice""" - return obj and user in obj.owners - - def check_resource_permissions( check_perms: Callable[..., Any], ) -> Callable[..., Any]: diff --git a/tests/integration_tests/charts/commands_tests.py b/tests/integration_tests/charts/commands_tests.py index ec205b6a6..214b7cbfe 100644 --- a/tests/integration_tests/charts/commands_tests.py +++ b/tests/integration_tests/charts/commands_tests.py @@ -350,58 +350,60 @@ class TestImportChartsCommand(SupersetTestCase): class TestChartsCreateCommand(SupersetTestCase): - @patch("superset.views.base.g") + @patch("superset.utils.core.g") + @patch("superset.charts.commands.create.g") @patch("superset.security.manager.g") @pytest.mark.usefixtures("load_energy_table_with_slice") - def test_create_v1_response(self, mock_sm_g, mock_g): + def test_create_v1_response(self, mock_sm_g, mock_c_g, mock_u_g): """Test that the create chart command creates a chart""" - actor = security_manager.find_user(username="admin") - mock_g.user = mock_sm_g.user = actor + user = security_manager.find_user(username="admin") + mock_u_g.user = mock_c_g.user = mock_sm_g.user = user chart_data = { "slice_name": "new chart", "description": "new description", - "owners": [actor.id], + "owners": [user.id], "viz_type": "new_viz_type", "params": json.dumps({"viz_type": "new_viz_type"}), "cache_timeout": 1000, "datasource_id": 1, "datasource_type": "table", } - command = CreateChartCommand(actor, chart_data) + command = CreateChartCommand(chart_data) chart = command.run() chart = db.session.query(Slice).get(chart.id) assert chart.viz_type == "new_viz_type" json_params = json.loads(chart.params) assert json_params == {"viz_type": "new_viz_type"} assert chart.slice_name == "new chart" - assert chart.owners == [actor] + assert chart.owners == [user] db.session.delete(chart) db.session.commit() class TestChartsUpdateCommand(SupersetTestCase): - @patch("superset.views.base.g") + @patch("superset.charts.commands.update.g") + @patch("superset.utils.core.g") @patch("superset.security.manager.g") @pytest.mark.usefixtures("load_energy_table_with_slice") - def test_update_v1_response(self, mock_sm_g, mock_g): + def test_update_v1_response(self, mock_sm_g, mock_c_g, mock_u_g): """Test that a chart command updates properties""" pk = db.session.query(Slice).all()[0].id - actor = security_manager.find_user(username="admin") - mock_g.user = mock_sm_g.user = actor + user = security_manager.find_user(username="admin") + mock_u_g.user = mock_c_g.user = mock_sm_g.user = user model_id = pk json_obj = { "description": "test for update", "cache_timeout": None, - "owners": [actor.id], + "owners": [user.id], } - command = UpdateChartCommand(actor, model_id, json_obj) + command = UpdateChartCommand(model_id, json_obj) last_saved_before = db.session.query(Slice).get(pk).last_saved_at command.run() chart = db.session.query(Slice).get(pk) assert chart.last_saved_at != last_saved_before - assert chart.last_saved_by == actor + assert chart.last_saved_by == user - @patch("superset.views.base.g") + @patch("superset.utils.core.g") @patch("superset.security.manager.g") @pytest.mark.usefixtures("load_energy_table_with_slice") def test_query_context_update_command(self, mock_sm_g, mock_g): @@ -415,14 +417,14 @@ class TestChartsUpdateCommand(SupersetTestCase): chart.owners = [admin] db.session.commit() - actor = security_manager.find_user(username="alpha") - mock_g.user = mock_sm_g.user = actor + user = security_manager.find_user(username="alpha") + mock_g.user = mock_sm_g.user = user query_context = json.dumps({"foo": "bar"}) json_obj = { "query_context_generation": True, "query_context": query_context, } - command = UpdateChartCommand(actor, pk, json_obj) + command = UpdateChartCommand(pk, json_obj) command.run() chart = db.session.query(Slice).get(pk) assert chart.query_context == query_context diff --git a/tests/integration_tests/databases/commands_tests.py b/tests/integration_tests/databases/commands_tests.py index 7f9daedea..ed8eb43cc 100644 --- a/tests/integration_tests/databases/commands_tests.py +++ b/tests/integration_tests/databases/commands_tests.py @@ -70,10 +70,11 @@ class TestCreateDatabaseCommand(SupersetTestCase): @mock.patch( "superset.databases.commands.test_connection.event_logger.log_with_context" ) - def test_create_duplicate_error(self, mock_logger): + @mock.patch("superset.utils.core.g") + def test_create_duplicate_error(self, mock_g, mock_logger): example_db = get_example_database() + mock_g.user = security_manager.find_user("admin") command = CreateDatabaseCommand( - security_manager.find_user("admin"), {"database_name": example_db.database_name}, ) with pytest.raises(DatabaseInvalidError) as excinfo: @@ -90,8 +91,10 @@ class TestCreateDatabaseCommand(SupersetTestCase): @mock.patch( "superset.databases.commands.test_connection.event_logger.log_with_context" ) - def test_multiple_error_logging(self, mock_logger): - command = CreateDatabaseCommand(security_manager.find_user("admin"), {}) + @mock.patch("superset.utils.core.g") + def test_multiple_error_logging(self, mock_g, mock_logger): + mock_g.user = security_manager.find_user("admin") + command = CreateDatabaseCommand({}) with pytest.raises(DatabaseInvalidError) as excinfo: command.run() assert str(excinfo.value) == ("Database parameters are invalid.") @@ -643,15 +646,17 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase): @mock.patch( "superset.databases.commands.test_connection.event_logger.log_with_context" ) - def test_connection_db_exception(self, mock_event_logger, mock_get_sqla_engine): + @mock.patch("superset.utils.core.g") + def test_connection_db_exception( + self, mock_g, mock_event_logger, mock_get_sqla_engine + ): """Test to make sure event_logger is called when an exception is raised""" database = get_example_database() + mock_g.user = security_manager.find_user("admin") mock_get_sqla_engine.side_effect = Exception("An error has occurred!") db_uri = database.sqlalchemy_uri_decrypted json_payload = {"sqlalchemy_uri": db_uri} - command_without_db_name = TestConnectionDatabaseCommand( - security_manager.find_user("admin"), json_payload - ) + command_without_db_name = TestConnectionDatabaseCommand(json_payload) with pytest.raises(DatabaseTestConnectionUnexpectedError) as excinfo: command_without_db_name.run() @@ -664,19 +669,19 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase): @mock.patch( "superset.databases.commands.test_connection.event_logger.log_with_context" ) + @mock.patch("superset.utils.core.g") def test_connection_do_ping_exception( - self, mock_event_logger, mock_get_sqla_engine + self, mock_g, mock_event_logger, mock_get_sqla_engine ): """Test to make sure do_ping exceptions gets captured""" database = get_example_database() + mock_g.user = security_manager.find_user("admin") mock_get_sqla_engine.return_value.dialect.do_ping.side_effect = Exception( "An error has occurred!" ) db_uri = database.sqlalchemy_uri_decrypted json_payload = {"sqlalchemy_uri": db_uri} - command_without_db_name = TestConnectionDatabaseCommand( - security_manager.find_user("admin"), json_payload - ) + command_without_db_name = TestConnectionDatabaseCommand(json_payload) with pytest.raises(DatabaseTestConnectionFailedError) as excinfo: command_without_db_name.run() @@ -689,15 +694,17 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase): @mock.patch( "superset.databases.commands.test_connection.event_logger.log_with_context" ) - def test_connection_do_ping_timeout(self, mock_event_logger, mock_func_timeout): + @mock.patch("superset.utils.core.g") + def test_connection_do_ping_timeout( + self, mock_g, mock_event_logger, mock_func_timeout + ): """Test to make sure do_ping exceptions gets captured""" database = get_example_database() + mock_g.user = security_manager.find_user("admin") mock_func_timeout.side_effect = FunctionTimedOut("Time out") db_uri = database.sqlalchemy_uri_decrypted json_payload = {"sqlalchemy_uri": db_uri} - command_without_db_name = TestConnectionDatabaseCommand( - security_manager.find_user("admin"), json_payload - ) + command_without_db_name = TestConnectionDatabaseCommand(json_payload) with pytest.raises(SupersetTimeoutException) as excinfo: command_without_db_name.run() @@ -711,20 +718,20 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase): @mock.patch( "superset.databases.commands.test_connection.event_logger.log_with_context" ) + @mock.patch("superset.utils.core.g") def test_connection_superset_security_connection( - self, mock_event_logger, mock_get_sqla_engine + self, mock_g, mock_event_logger, mock_get_sqla_engine ): """Test to make sure event_logger is called when security connection exc is raised""" database = get_example_database() + mock_g.user = security_manager.find_user("admin") mock_get_sqla_engine.side_effect = SupersetSecurityException( SupersetError(error_type=500, message="test", level="info") ) db_uri = database.sqlalchemy_uri_decrypted json_payload = {"sqlalchemy_uri": db_uri} - command_without_db_name = TestConnectionDatabaseCommand( - security_manager.find_user("admin"), json_payload - ) + command_without_db_name = TestConnectionDatabaseCommand(json_payload) with pytest.raises(DatabaseSecurityUnsafeError) as excinfo: command_without_db_name.run() @@ -736,17 +743,19 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase): @mock.patch( "superset.databases.commands.test_connection.event_logger.log_with_context" ) - def test_connection_db_api_exc(self, mock_event_logger, mock_get_sqla_engine): + @mock.patch("superset.utils.core.g") + def test_connection_db_api_exc( + self, mock_g, mock_event_logger, mock_get_sqla_engine + ): """Test to make sure event_logger is called when DBAPIError is raised""" database = get_example_database() + mock_g.user = security_manager.find_user("admin") mock_get_sqla_engine.side_effect = DBAPIError( statement="error", params={}, orig={} ) db_uri = database.sqlalchemy_uri_decrypted json_payload = {"sqlalchemy_uri": db_uri} - command_without_db_name = TestConnectionDatabaseCommand( - security_manager.find_user("admin"), json_payload - ) + command_without_db_name = TestConnectionDatabaseCommand(json_payload) with pytest.raises(DatabaseTestConnectionFailedError) as excinfo: command_without_db_name.run() @@ -778,7 +787,7 @@ def test_validate(DatabaseDAO, is_port_open, is_hostname_valid, app_context): "query": {}, }, } - command = ValidateDatabaseParametersCommand(None, payload) + command = ValidateDatabaseParametersCommand(payload) command.run() @@ -802,7 +811,7 @@ def test_validate_partial(is_port_open, is_hostname_valid, app_context): "query": {}, }, } - command = ValidateDatabaseParametersCommand(None, payload) + command = ValidateDatabaseParametersCommand(payload) with pytest.raises(SupersetErrorsException) as excinfo: command.run() assert excinfo.value.errors == [ @@ -841,7 +850,7 @@ def test_validate_partial_invalid_hostname(is_hostname_valid, app_context): "query": {}, }, } - command = ValidateDatabaseParametersCommand(None, payload) + command = ValidateDatabaseParametersCommand(payload) with pytest.raises(SupersetErrorsException) as excinfo: command.run() assert excinfo.value.errors == [ diff --git a/tests/integration_tests/explore/form_data/commands_tests.py b/tests/integration_tests/explore/form_data/commands_tests.py index 4db48cfa7..18dd8415f 100644 --- a/tests/integration_tests/explore/form_data/commands_tests.py +++ b/tests/integration_tests/explore/form_data/commands_tests.py @@ -110,7 +110,6 @@ class TestCreateFormDataCommand(SupersetTestCase): datasource = f"{dataset.id}__{DatasourceType.TABLE}" args = CommandParameters( - actor=mock_g.user, datasource_id=dataset.id, datasource_type=DatasourceType.TABLE, chart_id=slice.id, @@ -136,7 +135,6 @@ class TestCreateFormDataCommand(SupersetTestCase): datasource = f"{dataset.id}__{DatasourceType.TABLE}" create_args = CommandParameters( - actor=mock_g.user, datasource_id=dataset.id, datasource_type="InvalidType", chart_id=slice.id, @@ -163,7 +161,6 @@ class TestCreateFormDataCommand(SupersetTestCase): datasource = f"{dataset.id}__{DatasourceType.TABLE}" create_args = CommandParameters( - actor=mock_g.user, datasource_id=dataset.id, datasource_type="table", chart_id=slice.id, @@ -189,7 +186,6 @@ class TestCreateFormDataCommand(SupersetTestCase): datasource = f"{dataset.id}__{DatasourceType.TABLE}" create_args = CommandParameters( - actor=mock_g.user, datasource_id=dataset.id, datasource_type=DatasourceType.TABLE, chart_id=slice.id, @@ -198,7 +194,7 @@ class TestCreateFormDataCommand(SupersetTestCase): ) key = CreateFormDataCommand(create_args).run() - key_args = CommandParameters(actor=mock_g.user, key=key) + key_args = CommandParameters(key=key) get_command = GetFormDataCommand(key_args) cache_data = json.loads(get_command.run()) @@ -221,7 +217,6 @@ class TestCreateFormDataCommand(SupersetTestCase): datasource = f"{dataset.id}__{DatasourceType.TABLE}" create_args = CommandParameters( - actor=mock_g.user, datasource_id=dataset.id, datasource_type=DatasourceType.TABLE, chart_id=slice.id, @@ -232,7 +227,6 @@ class TestCreateFormDataCommand(SupersetTestCase): query_datasource = f"{dataset.id}__{DatasourceType.TABLE}" update_args = CommandParameters( - actor=mock_g.user, datasource_id=query.id, datasource_type=DatasourceType.QUERY, chart_id=slice.id, @@ -249,7 +243,7 @@ class TestCreateFormDataCommand(SupersetTestCase): # the updated key returned should be different from the old one assert new_key != key - key_args = CommandParameters(actor=mock_g.user, key=key) + key_args = CommandParameters(key=key) get_command = GetFormDataCommand(key_args) cache_data = json.loads(get_command.run()) @@ -271,7 +265,6 @@ class TestCreateFormDataCommand(SupersetTestCase): datasource = f"{dataset.id}__{DatasourceType.TABLE}" create_args = CommandParameters( - actor=mock_g.user, datasource_id=dataset.id, datasource_type=DatasourceType.TABLE, chart_id=slice.id, @@ -281,7 +274,6 @@ class TestCreateFormDataCommand(SupersetTestCase): key = CreateFormDataCommand(create_args).run() update_args = CommandParameters( - actor=mock_g.user, datasource_id=dataset.id, datasource_type=DatasourceType.TABLE, chart_id=slice.id, @@ -299,7 +291,7 @@ class TestCreateFormDataCommand(SupersetTestCase): # the updated key returned should be the same as the old one assert new_key == key - key_args = CommandParameters(actor=mock_g.user, key=key) + key_args = CommandParameters(key=key) get_command = GetFormDataCommand(key_args) cache_data = json.loads(get_command.run()) @@ -321,7 +313,6 @@ class TestCreateFormDataCommand(SupersetTestCase): datasource = f"{dataset.id}__{DatasourceType.TABLE}" create_args = CommandParameters( - actor=mock_g.user, datasource_id=dataset.id, datasource_type=DatasourceType.TABLE, chart_id=slice.id, @@ -331,7 +322,6 @@ class TestCreateFormDataCommand(SupersetTestCase): key = CreateFormDataCommand(create_args).run() delete_args = CommandParameters( - actor=mock_g.user, key=key, ) @@ -349,7 +339,6 @@ class TestCreateFormDataCommand(SupersetTestCase): } delete_args = CommandParameters( - actor=mock_g.user, key="some_expired_key", ) diff --git a/tests/integration_tests/explore/permalink/commands_tests.py b/tests/integration_tests/explore/permalink/commands_tests.py index 2bb44bb06..63ed02cd7 100644 --- a/tests/integration_tests/explore/permalink/commands_tests.py +++ b/tests/integration_tests/explore/permalink/commands_tests.py @@ -109,7 +109,7 @@ class TestCreatePermalinkDataCommand(SupersetTestCase): datasource = f"{dataset.id}__{DatasourceType.TABLE}" command = CreateExplorePermalinkCommand( - mock_g.user, {"formData": {"datasource": datasource, "slice_id": slice.id}} + {"formData": {"datasource": datasource, "slice_id": slice.id}} ) assert isinstance(command.run(), str) @@ -130,10 +130,10 @@ class TestCreatePermalinkDataCommand(SupersetTestCase): datasource = f"{dataset.id}__{DatasourceType.TABLE}" key = CreateExplorePermalinkCommand( - mock_g.user, {"formData": {"datasource": datasource, "slice_id": slice.id}} + {"formData": {"datasource": datasource, "slice_id": slice.id}} ).run() - get_command = GetExplorePermalinkCommand(mock_g.user, key) + get_command = GetExplorePermalinkCommand(key) cache_data = get_command.run() assert cache_data.get("datasource") == datasource @@ -166,7 +166,7 @@ class TestCreatePermalinkDataCommand(SupersetTestCase): "formData": {"datasource": datasource_string, "slice_id": slice.id} }, } - get_command = GetExplorePermalinkCommand(mock_g.user, "thisisallmocked") + get_command = GetExplorePermalinkCommand("thisisallmocked") cache_data = get_command.run() assert cache_data.get("datasource") == datasource_string diff --git a/tests/integration_tests/key_value/commands/create_test.py b/tests/integration_tests/key_value/commands/create_test.py index 2718aa822..0e789026b 100644 --- a/tests/integration_tests/key_value/commands/create_test.py +++ b/tests/integration_tests/key_value/commands/create_test.py @@ -23,6 +23,7 @@ from flask.ctx import AppContext from flask_appbuilder.security.sqla.models import User from superset.extensions import db +from superset.utils.core import override_user from tests.integration_tests.key_value.commands.fixtures import ( admin, ID_KEY, @@ -36,19 +37,23 @@ def test_create_id_entry(app_context: AppContext, admin: User) -> None: from superset.key_value.commands.create import CreateKeyValueCommand from superset.key_value.models import KeyValueEntry - key = CreateKeyValueCommand(actor=admin, resource=RESOURCE, value=VALUE).run() - entry = db.session.query(KeyValueEntry).filter_by(id=key.id).autoflush(False).one() - assert pickle.loads(entry.value) == VALUE - assert entry.created_by_fk == admin.id - db.session.delete(entry) - db.session.commit() + with override_user(admin): + key = CreateKeyValueCommand(resource=RESOURCE, value=VALUE).run() + entry = ( + db.session.query(KeyValueEntry).filter_by(id=key.id).autoflush(False).one() + ) + assert pickle.loads(entry.value) == VALUE + assert entry.created_by_fk == admin.id + db.session.delete(entry) + db.session.commit() def test_create_uuid_entry(app_context: AppContext, admin: User) -> None: from superset.key_value.commands.create import CreateKeyValueCommand from superset.key_value.models import KeyValueEntry - key = CreateKeyValueCommand(actor=admin, resource=RESOURCE, value=VALUE).run() + with override_user(admin): + key = CreateKeyValueCommand(resource=RESOURCE, value=VALUE).run() entry = ( db.session.query(KeyValueEntry).filter_by(uuid=key.uuid).autoflush(False).one() ) diff --git a/tests/integration_tests/key_value/commands/update_test.py b/tests/integration_tests/key_value/commands/update_test.py index 3b24ecdf0..8eb03b4ed 100644 --- a/tests/integration_tests/key_value/commands/update_test.py +++ b/tests/integration_tests/key_value/commands/update_test.py @@ -24,6 +24,7 @@ from flask.ctx import AppContext from flask_appbuilder.security.sqla.models import User from superset.extensions import db +from superset.utils.core import override_user from tests.integration_tests.key_value.commands.fixtures import ( admin, ID_KEY, @@ -47,12 +48,12 @@ def test_update_id_entry( from superset.key_value.commands.update import UpdateKeyValueCommand from superset.key_value.models import KeyValueEntry - key = UpdateKeyValueCommand( - actor=admin, - resource=RESOURCE, - key=ID_KEY, - value=NEW_VALUE, - ).run() + with override_user(admin): + key = UpdateKeyValueCommand( + resource=RESOURCE, + key=ID_KEY, + value=NEW_VALUE, + ).run() assert key is not None assert key.id == ID_KEY entry = db.session.query(KeyValueEntry).filter_by(id=ID_KEY).autoflush(False).one() @@ -68,12 +69,12 @@ def test_update_uuid_entry( from superset.key_value.commands.update import UpdateKeyValueCommand from superset.key_value.models import KeyValueEntry - key = UpdateKeyValueCommand( - actor=admin, - resource=RESOURCE, - key=UUID_KEY, - value=NEW_VALUE, - ).run() + with override_user(admin): + key = UpdateKeyValueCommand( + resource=RESOURCE, + key=UUID_KEY, + value=NEW_VALUE, + ).run() assert key is not None assert key.uuid == UUID_KEY entry = ( @@ -86,10 +87,10 @@ def test_update_uuid_entry( def test_update_missing_entry(app_context: AppContext, admin: User) -> None: from superset.key_value.commands.update import UpdateKeyValueCommand - key = UpdateKeyValueCommand( - actor=admin, - resource=RESOURCE, - key=456, - value=NEW_VALUE, - ).run() + with override_user(admin): + key = UpdateKeyValueCommand( + resource=RESOURCE, + key=456, + value=NEW_VALUE, + ).run() assert key is None diff --git a/tests/integration_tests/key_value/commands/upsert_test.py b/tests/integration_tests/key_value/commands/upsert_test.py index 1970a1fc2..e5cd27e3a 100644 --- a/tests/integration_tests/key_value/commands/upsert_test.py +++ b/tests/integration_tests/key_value/commands/upsert_test.py @@ -24,6 +24,7 @@ from flask.ctx import AppContext from flask_appbuilder.security.sqla.models import User from superset.extensions import db +from superset.utils.core import override_user from tests.integration_tests.key_value.commands.fixtures import ( admin, ID_KEY, @@ -47,12 +48,12 @@ def test_upsert_id_entry( from superset.key_value.commands.upsert import UpsertKeyValueCommand from superset.key_value.models import KeyValueEntry - key = UpsertKeyValueCommand( - actor=admin, - resource=RESOURCE, - key=ID_KEY, - value=NEW_VALUE, - ).run() + with override_user(admin): + key = UpsertKeyValueCommand( + resource=RESOURCE, + key=ID_KEY, + value=NEW_VALUE, + ).run() assert key is not None assert key.id == ID_KEY entry = ( @@ -70,12 +71,12 @@ def test_upsert_uuid_entry( from superset.key_value.commands.upsert import UpsertKeyValueCommand from superset.key_value.models import KeyValueEntry - key = UpsertKeyValueCommand( - actor=admin, - resource=RESOURCE, - key=UUID_KEY, - value=NEW_VALUE, - ).run() + with override_user(admin): + key = UpsertKeyValueCommand( + resource=RESOURCE, + key=UUID_KEY, + value=NEW_VALUE, + ).run() assert key is not None assert key.uuid == UUID_KEY entry = ( @@ -89,12 +90,12 @@ def test_upsert_missing_entry(app_context: AppContext, admin: User) -> None: from superset.key_value.commands.upsert import UpsertKeyValueCommand from superset.key_value.models import KeyValueEntry - key = UpsertKeyValueCommand( - actor=admin, - resource=RESOURCE, - key=456, - value=NEW_VALUE, - ).run() + with override_user(admin): + key = UpsertKeyValueCommand( + resource=RESOURCE, + key=456, + value=NEW_VALUE, + ).run() assert key is not None assert key.id == 456 db.session.query(KeyValueEntry).filter_by(id=456).delete() diff --git a/tests/integration_tests/security_tests.py b/tests/integration_tests/security_tests.py index 81781d16c..476cf27aa 100644 --- a/tests/integration_tests/security_tests.py +++ b/tests/integration_tests/security_tests.py @@ -432,8 +432,8 @@ class TestRolePermission(SupersetTestCase): # TODO test slice permission - @patch("superset.security.manager.g") @patch("superset.utils.core.g") + @patch("superset.security.manager.g") def test_schemas_accessible_by_user_admin(self, mock_sm_g, mock_g): mock_g.user = mock_sm_g.user = security_manager.find_user("admin") with self.client.application.test_request_context(): @@ -443,8 +443,8 @@ class TestRolePermission(SupersetTestCase): ) self.assertEqual(schemas, ["1", "2", "3"]) # no changes - @patch("superset.security.manager.g") @patch("superset.utils.core.g") + @patch("superset.security.manager.g") def test_schemas_accessible_by_user_schema_access(self, mock_sm_g, mock_g): # User has schema access to the schema 1 create_schema_perm("[examples].[1]") @@ -458,8 +458,8 @@ class TestRolePermission(SupersetTestCase): self.assertEqual(schemas, ["1"]) delete_schema_perm("[examples].[1]") - @patch("superset.security.manager.g") @patch("superset.utils.core.g") + @patch("superset.security.manager.g") def test_schemas_accessible_by_user_datasource_access(self, mock_sm_g, mock_g): # User has schema access to the datasource temp_schema.wb_health_population in examples DB. mock_g.user = mock_sm_g.user = security_manager.find_user("gamma") @@ -470,8 +470,8 @@ class TestRolePermission(SupersetTestCase): ) self.assertEqual(schemas, ["temp_schema"]) - @patch("superset.security.manager.g") @patch("superset.utils.core.g") + @patch("superset.security.manager.g") def test_schemas_accessible_by_user_datasource_and_schema_access( self, mock_sm_g, mock_g ): @@ -904,9 +904,9 @@ class TestSecurityManager(SupersetTestCase): self.assertFalse(security_manager.can_access_table(database, table)) + @patch("superset.security.SupersetSecurityManager.is_owner") @patch("superset.security.SupersetSecurityManager.can_access") @patch("superset.security.SupersetSecurityManager.can_access_schema") - @patch("superset.views.utils.is_owner") def test_raise_for_access_datasource( self, mock_can_access_schema, mock_can_access, mock_is_owner ): @@ -922,8 +922,8 @@ class TestSecurityManager(SupersetTestCase): with self.assertRaises(SupersetSecurityException): security_manager.raise_for_access(datasource=datasource) + @patch("superset.security.SupersetSecurityManager.is_owner") @patch("superset.security.SupersetSecurityManager.can_access") - @patch("superset.views.utils.is_owner") def test_raise_for_access_query(self, mock_can_access, mock_is_owner): query = Mock( database=get_example_database(), schema="bar", sql="SELECT * FROM foo" @@ -938,10 +938,11 @@ class TestSecurityManager(SupersetTestCase): with self.assertRaises(SupersetSecurityException): security_manager.raise_for_access(query=query) + @patch("superset.security.SupersetSecurityManager.is_owner") @patch("superset.security.SupersetSecurityManager.can_access") @patch("superset.security.SupersetSecurityManager.can_access_schema") def test_raise_for_access_query_context( - self, mock_can_access_schema, mock_can_access + self, mock_can_access_schema, mock_can_access, mock_is_owner ): query_context = Mock(datasource=self.get_datasource_mock()) @@ -950,6 +951,7 @@ class TestSecurityManager(SupersetTestCase): mock_can_access.return_value = False mock_can_access_schema.return_value = False + mock_is_owner.return_value = False with self.assertRaises(SupersetSecurityException): security_manager.raise_for_access(query_context=query_context) @@ -967,9 +969,12 @@ class TestSecurityManager(SupersetTestCase): with self.assertRaises(SupersetSecurityException): security_manager.raise_for_access(database=database, table=table) + @patch("superset.security.SupersetSecurityManager.is_owner") @patch("superset.security.SupersetSecurityManager.can_access") @patch("superset.security.SupersetSecurityManager.can_access_schema") - def test_raise_for_access_viz(self, mock_can_access_schema, mock_can_access): + def test_raise_for_access_viz( + self, mock_can_access_schema, mock_can_access, mock_is_owner + ): test_viz = viz.TableViz(self.get_datasource_mock(), form_data={}) mock_can_access_schema.return_value = True @@ -977,6 +982,7 @@ class TestSecurityManager(SupersetTestCase): mock_can_access.return_value = False mock_can_access_schema.return_value = False + mock_is_owner.return_value = False with self.assertRaises(SupersetSecurityException): security_manager.raise_for_access(viz=test_viz) diff --git a/tests/unit_tests/explore/utils_test.py b/tests/unit_tests/explore/utils_test.py index 64aefbf43..06bde3c4e 100644 --- a/tests/unit_tests/explore/utils_test.py +++ b/tests/unit_tests/explore/utils_test.py @@ -34,13 +34,13 @@ from superset.datasets.commands.exceptions import ( DatasetNotFoundError, ) from superset.exceptions import SupersetSecurityException -from superset.utils.core import DatasourceType +from superset.utils.core import DatasourceType, override_user dataset_find_by_id = "superset.datasets.dao.DatasetDAO.find_by_id" query_find_by_id = "superset.queries.dao.QueryDAO.find_by_id" chart_find_by_id = "superset.charts.dao.ChartDAO.find_by_id" -is_user_admin = "superset.explore.utils.is_user_admin" -is_owner = "superset.explore.utils.is_owner" +is_admin = "superset.security.SupersetSecurityManager.is_admin" +is_owner = "superset.security.SupersetSecurityManager.is_owner" can_access_datasource = ( "superset.security.SupersetSecurityManager.can_access_datasource" ) @@ -55,12 +55,12 @@ def test_unsaved_chart_no_dataset_id(app_context: AppContext) -> None: from superset.explore.utils import check_access as check_chart_access with raises(DatasourceNotFoundValidationError): - check_chart_access( - datasource_id=0, - chart_id=0, - actor=User(), - datasource_type=DatasourceType.TABLE, - ) + with override_user(User()): + check_chart_access( + datasource_id=0, + chart_id=0, + datasource_type=DatasourceType.TABLE, + ) def test_unsaved_chart_unknown_dataset_id( @@ -70,12 +70,13 @@ def test_unsaved_chart_unknown_dataset_id( with raises(DatasetNotFoundError): mocker.patch(dataset_find_by_id, return_value=None) - check_chart_access( - datasource_id=1, - chart_id=0, - actor=User(), - datasource_type=DatasourceType.TABLE, - ) + + with override_user(User()): + check_chart_access( + datasource_id=1, + chart_id=0, + datasource_type=DatasourceType.TABLE, + ) def test_unsaved_chart_unknown_query_id( @@ -85,12 +86,13 @@ def test_unsaved_chart_unknown_query_id( with raises(QueryNotFoundValidationError): mocker.patch(query_find_by_id, return_value=None) - check_chart_access( - datasource_id=1, - chart_id=0, - actor=User(), - datasource_type=DatasourceType.QUERY, - ) + + with override_user(User()): + check_chart_access( + datasource_id=1, + chart_id=0, + datasource_type=DatasourceType.QUERY, + ) def test_unsaved_chart_unauthorized_dataset( @@ -102,12 +104,13 @@ def test_unsaved_chart_unauthorized_dataset( with raises(DatasetAccessDeniedError): mocker.patch(dataset_find_by_id, return_value=SqlaTable()) mocker.patch(can_access_datasource, return_value=False) - check_chart_access( - datasource_id=1, - chart_id=0, - actor=User(), - datasource_type=DatasourceType.TABLE, - ) + + with override_user(User()): + check_chart_access( + datasource_id=1, + chart_id=0, + datasource_type=DatasourceType.TABLE, + ) def test_unsaved_chart_authorized_dataset( @@ -118,12 +121,13 @@ def test_unsaved_chart_authorized_dataset( mocker.patch(dataset_find_by_id, return_value=SqlaTable()) mocker.patch(can_access_datasource, return_value=True) - check_chart_access( - datasource_id=1, - chart_id=0, - actor=User(), - datasource_type=DatasourceType.TABLE, - ) + + with override_user(User()): + check_chart_access( + datasource_id=1, + chart_id=0, + datasource_type=DatasourceType.TABLE, + ) def test_saved_chart_unknown_chart_id( @@ -136,12 +140,13 @@ def test_saved_chart_unknown_chart_id( mocker.patch(dataset_find_by_id, return_value=SqlaTable()) mocker.patch(can_access_datasource, return_value=True) mocker.patch(chart_find_by_id, return_value=None) - check_chart_access( - datasource_id=1, - chart_id=1, - actor=User(), - datasource_type=DatasourceType.TABLE, - ) + + with override_user(User()): + check_chart_access( + datasource_id=1, + chart_id=1, + datasource_type=DatasourceType.TABLE, + ) def test_saved_chart_unauthorized_dataset( @@ -153,12 +158,13 @@ def test_saved_chart_unauthorized_dataset( with raises(DatasetAccessDeniedError): mocker.patch(dataset_find_by_id, return_value=SqlaTable()) mocker.patch(can_access_datasource, return_value=False) - check_chart_access( - datasource_id=1, - chart_id=1, - actor=User(), - datasource_type=DatasourceType.TABLE, - ) + + with override_user(User()): + check_chart_access( + datasource_id=1, + chart_id=1, + datasource_type=DatasourceType.TABLE, + ) def test_saved_chart_is_admin(mocker: MockFixture, app_context: AppContext) -> None: @@ -168,14 +174,15 @@ def test_saved_chart_is_admin(mocker: MockFixture, app_context: AppContext) -> N mocker.patch(dataset_find_by_id, return_value=SqlaTable()) mocker.patch(can_access_datasource, return_value=True) - mocker.patch(is_user_admin, return_value=True) + mocker.patch(is_admin, return_value=True) mocker.patch(chart_find_by_id, return_value=Slice()) - check_chart_access( - datasource_id=1, - chart_id=1, - actor=User(), - datasource_type=DatasourceType.TABLE, - ) + + with override_user(User()): + check_chart_access( + datasource_id=1, + chart_id=1, + datasource_type=DatasourceType.TABLE, + ) def test_saved_chart_is_owner(mocker: MockFixture, app_context: AppContext) -> None: @@ -185,15 +192,16 @@ def test_saved_chart_is_owner(mocker: MockFixture, app_context: AppContext) -> N mocker.patch(dataset_find_by_id, return_value=SqlaTable()) mocker.patch(can_access_datasource, return_value=True) - mocker.patch(is_user_admin, return_value=False) + mocker.patch(is_admin, return_value=False) mocker.patch(is_owner, return_value=True) mocker.patch(chart_find_by_id, return_value=Slice()) - check_chart_access( - datasource_id=1, - chart_id=1, - actor=User(), - datasource_type=DatasourceType.TABLE, - ) + + with override_user(User()): + check_chart_access( + datasource_id=1, + chart_id=1, + datasource_type=DatasourceType.TABLE, + ) def test_saved_chart_has_access(mocker: MockFixture, app_context: AppContext) -> None: @@ -203,16 +211,17 @@ def test_saved_chart_has_access(mocker: MockFixture, app_context: AppContext) -> mocker.patch(dataset_find_by_id, return_value=SqlaTable()) mocker.patch(can_access_datasource, return_value=True) - mocker.patch(is_user_admin, return_value=False) + mocker.patch(is_admin, return_value=False) mocker.patch(is_owner, return_value=False) mocker.patch(can_access, return_value=True) mocker.patch(chart_find_by_id, return_value=Slice()) - check_chart_access( - datasource_id=1, - chart_id=1, - actor=User(), - datasource_type=DatasourceType.TABLE, - ) + + with override_user(User()): + check_chart_access( + datasource_id=1, + chart_id=1, + datasource_type=DatasourceType.TABLE, + ) def test_saved_chart_no_access(mocker: MockFixture, app_context: AppContext) -> None: @@ -223,16 +232,17 @@ def test_saved_chart_no_access(mocker: MockFixture, app_context: AppContext) -> with raises(ChartAccessDeniedError): mocker.patch(dataset_find_by_id, return_value=SqlaTable()) mocker.patch(can_access_datasource, return_value=True) - mocker.patch(is_user_admin, return_value=False) + mocker.patch(is_admin, return_value=False) mocker.patch(is_owner, return_value=False) mocker.patch(can_access, return_value=False) mocker.patch(chart_find_by_id, return_value=Slice()) - check_chart_access( - datasource_id=1, - chart_id=1, - actor=User(), - datasource_type=DatasourceType.TABLE, - ) + + with override_user(User()): + check_chart_access( + datasource_id=1, + chart_id=1, + datasource_type=DatasourceType.TABLE, + ) def test_dataset_has_access(mocker: MockFixture, app_context: AppContext) -> None: @@ -241,7 +251,7 @@ def test_dataset_has_access(mocker: MockFixture, app_context: AppContext) -> Non mocker.patch(dataset_find_by_id, return_value=SqlaTable()) mocker.patch(can_access_datasource, return_value=True) - mocker.patch(is_user_admin, return_value=False) + mocker.patch(is_admin, return_value=False) mocker.patch(is_owner, return_value=False) mocker.patch(can_access, return_value=True) assert ( @@ -259,7 +269,7 @@ def test_query_has_access(mocker: MockFixture, app_context: AppContext) -> None: mocker.patch(query_find_by_id, return_value=Query()) mocker.patch(raise_for_access, return_value=True) - mocker.patch(is_user_admin, return_value=False) + mocker.patch(is_admin, return_value=False) mocker.patch(is_owner, return_value=False) mocker.patch(can_access, return_value=True) assert ( @@ -282,10 +292,8 @@ def test_query_no_access(mocker: MockFixture, client, app_context: AppContext) - query_find_by_id, return_value=Query(database=Database(), sql="select * from foo"), ) - table = SqlaTable() - table.owners = [] - mocker.patch(query_datasources_by_name, return_value=[table]) - mocker.patch(is_user_admin, return_value=False) + mocker.patch(query_datasources_by_name, return_value=[SqlaTable()]) + mocker.patch(is_admin, return_value=False) mocker.patch(is_owner, return_value=False) mocker.patch(can_access, return_value=False) check_datasource_access(