diff --git a/superset/app.py b/superset/app.py index 29141c24b..dc706fdc0 100644 --- a/superset/app.py +++ b/superset/app.py @@ -130,6 +130,7 @@ class SupersetAppInitializer: DruidColumnInlineView, Druid, ) + from superset.datasets.api import DatasetRestApi from superset.connectors.sqla.views import ( TableColumnInlineView, SqlMetricInlineView, @@ -182,7 +183,7 @@ class SupersetAppInitializer: appbuilder.add_api(ChartRestApi) appbuilder.add_api(DashboardRestApi) appbuilder.add_api(DatabaseRestApi) - + appbuilder.add_api(DatasetRestApi) # # Setup regular views # diff --git a/superset/commands/__init__.py b/superset/commands/__init__.py new file mode 100644 index 000000000..13a83393a --- /dev/null +++ b/superset/commands/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/superset/commands/base.py b/superset/commands/base.py new file mode 100644 index 000000000..9889d6f0c --- /dev/null +++ b/superset/commands/base.py @@ -0,0 +1,39 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from abc import ABC, abstractmethod + + +class BaseCommand(ABC): + """ + Base class for all Command like Superset Logic objects + """ + + @abstractmethod + def run(self): + """ + Run executes the command. Can raise command exceptions + :return: + """ + pass + + @abstractmethod + def validate(self) -> None: + """ + Validate is normally called by run to validate data. + Will raise exception if validation fails + """ + pass diff --git a/superset/commands/exceptions.py b/superset/commands/exceptions.py new file mode 100644 index 000000000..83b3e1df4 --- /dev/null +++ b/superset/commands/exceptions.py @@ -0,0 +1,71 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import List, Optional + +from marshmallow import ValidationError + + +class CommandException(Exception): + """ Common base class for Command exceptions. """ + + message = "" + + def __init__(self, message: str = "", exception: Optional[Exception] = None): + if message: + self.message = message + self._exception = exception + super().__init__(self.message) + + @property + def exception(self): + return self._exception + + +class CommandInvalidError(CommandException): + """ Common base class for Command Invalid errors. """ + + def __init__(self, message=""): + self._invalid_exceptions = list() + super().__init__(self.message) + + def add(self, exception: ValidationError): + self._invalid_exceptions.append(exception) + + def add_list(self, exceptions: List[ValidationError]): + self._invalid_exceptions.extend(exceptions) + + def normalized_messages(self): + errors = {} + for exception in self._invalid_exceptions: + errors.update(exception.normalized_messages()) + return errors + + +class UpdateFailedError(CommandException): + message = "Command update failed" + + +class CreateFailedError(CommandException): + message = "Command create failed" + + +class DeleteFailedError(CommandException): + message = "Command delete failed" + + +class ForbiddenError(CommandException): + message = "Action is forbidden" diff --git a/superset/datasets/__init__.py b/superset/datasets/__init__.py new file mode 100644 index 000000000..13a83393a --- /dev/null +++ b/superset/datasets/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/superset/datasets/api.py b/superset/datasets/api.py new file mode 100644 index 000000000..64821dbfd --- /dev/null +++ b/superset/datasets/api.py @@ -0,0 +1,265 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import logging + +from flask import g, request, Response +from flask_appbuilder.api import expose, protect, safe +from flask_appbuilder.models.sqla.interface import SQLAInterface + +from superset.connectors.sqla.models import SqlaTable +from superset.constants import RouteMethod +from superset.datasets.commands.create import CreateDatasetCommand +from superset.datasets.commands.delete import DeleteDatasetCommand +from superset.datasets.commands.exceptions import ( + DatasetCreateFailedError, + DatasetDeleteFailedError, + DatasetForbiddenError, + DatasetInvalidError, + DatasetNotFoundError, + DatasetUpdateFailedError, +) +from superset.datasets.commands.update import UpdateDatasetCommand +from superset.datasets.schemas import DatasetPostSchema, DatasetPutSchema +from superset.views.base import DatasourceFilter +from superset.views.base_api import BaseSupersetModelRestApi +from superset.views.database.filters import DatabaseFilter + +logger = logging.getLogger(__name__) + + +class DatasetRestApi(BaseSupersetModelRestApi): + datamodel = SQLAInterface(SqlaTable) + base_filters = [["id", DatasourceFilter, lambda: []]] + + resource_name = "dataset" + allow_browser_login = True + + class_permission_name = "TableModelView" + include_route_methods = RouteMethod.REST_MODEL_VIEW_CRUD_SET | {RouteMethod.RELATED} + + list_columns = [ + "database_name", + "changed_by.username", + "changed_on", + "table_name", + "schema", + ] + show_columns = [ + "database.database_name", + "database.id", + "table_name", + "sql", + "filter_select_enabled", + "fetch_values_predicate", + "schema", + "description", + "main_dttm_col", + "offset", + "default_endpoint", + "cache_timeout", + "is_sqllab_view", + "template_params", + "owners.id", + "owners.username", + ] + add_model_schema = DatasetPostSchema() + edit_model_schema = DatasetPutSchema() + add_columns = ["database", "schema", "table_name", "owners"] + edit_columns = [ + "table_name", + "sql", + "filter_select_enabled", + "fetch_values_predicate", + "schema", + "description", + "main_dttm_col", + "offset", + "default_endpoint", + "cache_timeout", + "is_sqllab_view", + "template_params", + "owners", + ] + openapi_spec_tag = "Datasets" + + filter_rel_fields_field = {"owners": "first_name", "database": "database_name"} + filter_rel_fields = {"database": [["id", DatabaseFilter, lambda: []]]} + + @expose("/", methods=["POST"]) + @protect() + @safe + def post(self) -> Response: + """Creates a new Dataset + --- + post: + description: >- + Create a new Dataset + requestBody: + description: Dataset schema + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/{{self.__class__.__name__}}.post' + responses: + 201: + description: Dataset added + content: + application/json: + schema: + type: object + properties: + id: + type: number + result: + $ref: '#/components/schemas/{{self.__class__.__name__}}.post' + 400: + $ref: '#/components/responses/400' + 401: + $ref: '#/components/responses/401' + 422: + $ref: '#/components/responses/422' + 500: + $ref: '#/components/responses/500' + """ + if not request.is_json: + return self.response_400(message="Request is not JSON") + item = self.add_model_schema.load(request.json) + # This validates custom Schema with custom validations + if item.errors: + return self.response_400(message=item.errors) + try: + new_model = CreateDatasetCommand(g.user, item.data).run() + return self.response(201, id=new_model.id, result=item.data) + except DatasetInvalidError as e: + return self.response_422(message=e.normalized_messages()) + except DatasetCreateFailedError as e: + logger.error(f"Error creating model {self.__class__.__name__}: {e}") + return self.response_422(message=str(e)) + + @expose("/", methods=["PUT"]) + @protect() + @safe + def put( # pylint: disable=too-many-return-statements, arguments-differ + self, pk: int + ) -> Response: + """Changes a Dataset + --- + put: + description: >- + Changes a Dataset + parameters: + - in: path + schema: + type: integer + name: pk + requestBody: + description: Dataset schema + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/{{self.__class__.__name__}}.put' + responses: + 200: + description: Dataset changed + content: + application/json: + schema: + type: object + properties: + id: + type: number + result: + $ref: '#/components/schemas/{{self.__class__.__name__}}.put' + 400: + $ref: '#/components/responses/400' + 401: + $ref: '#/components/responses/401' + 403: + $ref: '#/components/responses/403' + 404: + $ref: '#/components/responses/404' + 422: + $ref: '#/components/responses/422' + 500: + $ref: '#/components/responses/500' + """ + if not request.is_json: + return self.response_400(message="Request is not JSON") + item = self.edit_model_schema.load(request.json) + # This validates custom Schema with custom validations + if item.errors: + return self.response_400(message=item.errors) + try: + changed_model = UpdateDatasetCommand(g.user, pk, item.data).run() + return self.response(200, id=changed_model.id, result=item.data) + except DatasetNotFoundError: + return self.response_404() + except DatasetForbiddenError: + return self.response_403() + except DatasetInvalidError as e: + return self.response_422(message=e.normalized_messages()) + except DatasetUpdateFailedError as e: + logger.error(f"Error updating model {self.__class__.__name__}: {e}") + return self.response_422(message=str(e)) + + @expose("/", methods=["DELETE"]) + @protect() + @safe + def delete(self, pk: int) -> Response: # pylint: disable=arguments-differ + """Deletes a Dataset + --- + delete: + description: >- + Deletes a Dataset + parameters: + - in: path + schema: + type: integer + name: pk + responses: + 200: + description: Dataset delete + content: + application/json: + schema: + type: object + properties: + message: + type: string + 401: + $ref: '#/components/responses/401' + 403: + $ref: '#/components/responses/403' + 404: + $ref: '#/components/responses/404' + 422: + $ref: '#/components/responses/422' + 500: + $ref: '#/components/responses/500' + """ + try: + DeleteDatasetCommand(g.user, pk).run() + return self.response(200, message="OK") + except DatasetNotFoundError: + return self.response_404() + except DatasetForbiddenError: + return self.response_403() + except DatasetDeleteFailedError as e: + logger.error(f"Error deleting model {self.__class__.__name__}: {e}") + return self.response_422(message=str(e)) diff --git a/superset/datasets/commands/__init__.py b/superset/datasets/commands/__init__.py new file mode 100644 index 000000000..13a83393a --- /dev/null +++ b/superset/datasets/commands/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/superset/datasets/commands/base.py b/superset/datasets/commands/base.py new file mode 100644 index 000000000..646dfc328 --- /dev/null +++ b/superset/datasets/commands/base.py @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import List, Optional + +from flask_appbuilder.security.sqla.models import User + +from superset.datasets.commands.exceptions import OwnersNotFoundValidationError +from superset.datasets.dao import DatasetDAO + + +def populate_owners(user: User, owners_ids: Optional[List[int]] = None) -> List[User]: + """ + Helper function for commands, will fetch all users from owners id's + Can raise ValidationError + + :param user: The current user + :param owners_ids: A List of owners by id's + """ + owners = list() + if not owners_ids: + return [user] + if user.id not in owners_ids: + owners.append(user) + for owner_id in owners_ids: + owner = DatasetDAO.get_owner_by_id(owner_id) + if not owner: + raise OwnersNotFoundValidationError() + owners.append(owner) + return owners diff --git a/superset/datasets/commands/create.py b/superset/datasets/commands/create.py new file mode 100644 index 000000000..344770b95 --- /dev/null +++ b/superset/datasets/commands/create.py @@ -0,0 +1,83 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import logging +from typing import Dict, List, Optional + +from flask_appbuilder.security.sqla.models import User +from marshmallow import ValidationError + +from superset.commands.base import BaseCommand +from superset.commands.exceptions import CreateFailedError +from superset.datasets.commands.base import populate_owners +from superset.datasets.commands.exceptions import ( + DatabaseNotFoundValidationError, + DatasetCreateFailedError, + DatasetExistsValidationError, + DatasetInvalidError, + TableNotFoundValidationError, +) +from superset.datasets.dao import DatasetDAO + +logger = logging.getLogger(__name__) + + +class CreateDatasetCommand(BaseCommand): + def __init__(self, user: User, data: Dict): + self._actor = user + self._properties = data.copy() + + def run(self): + self.validate() + try: + dataset = DatasetDAO.create(self._properties) + except CreateFailedError as e: + logger.exception(e.exception) + raise DatasetCreateFailedError() + return dataset + + def validate(self) -> None: + exceptions = list() + database_id = self._properties["database"] + table_name = self._properties["table_name"] + schema = self._properties.get("schema", "") + owner_ids: Optional[List[int]] = self._properties.get("owners") + + # Validate uniqueness + if not DatasetDAO.validate_uniqueness(database_id, table_name): + exceptions.append(DatasetExistsValidationError(table_name)) + + # Validate/Populate database + database = DatasetDAO.get_database_by_id(database_id) + if not database: + exceptions.append(DatabaseNotFoundValidationError()) + self._properties["database"] = database + + # Validate table exists on dataset + if database and not DatasetDAO.validate_table_exists( + database, table_name, schema + ): + exceptions.append(TableNotFoundValidationError(table_name)) + + try: + owners = populate_owners(self._actor, owner_ids) + self._properties["owners"] = owners + except ValidationError as e: + exceptions.append(e) + if exceptions: + exception = DatasetInvalidError() + exception.add_list(exceptions) + raise exception diff --git a/superset/datasets/commands/delete.py b/superset/datasets/commands/delete.py new file mode 100644 index 000000000..d61c56a0e --- /dev/null +++ b/superset/datasets/commands/delete.py @@ -0,0 +1,61 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import logging +from typing import Optional + +from flask_appbuilder.security.sqla.models import User + +from superset.commands.base import BaseCommand +from superset.commands.exceptions import DeleteFailedError +from superset.connectors.sqla.models import SqlaTable +from superset.datasets.commands.exceptions import ( + DatasetDeleteFailedError, + DatasetForbiddenError, + DatasetNotFoundError, +) +from superset.datasets.dao import DatasetDAO +from superset.exceptions import SupersetSecurityException +from superset.views.base import check_ownership + +logger = logging.getLogger(__name__) + + +class DeleteDatasetCommand(BaseCommand): + def __init__(self, user: User, model_id: int): + self._actor = user + self._model_id = model_id + self._model: Optional[SqlaTable] = None + + def run(self): + self.validate() + try: + dataset = DatasetDAO.delete(self._model) + except DeleteFailedError as e: + logger.exception(e.exception) + raise DatasetDeleteFailedError() + return dataset + + def validate(self) -> None: + # Validate/populate model exists + self._model = DatasetDAO.find_by_id(self._model_id) + if not self._model: + raise DatasetNotFoundError() + # Check ownership + try: + check_ownership(self._model) + except SupersetSecurityException: + raise DatasetForbiddenError() diff --git a/superset/datasets/commands/exceptions.py b/superset/datasets/commands/exceptions.py new file mode 100644 index 000000000..a6d0ed7de --- /dev/null +++ b/superset/datasets/commands/exceptions.py @@ -0,0 +1,103 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from flask_babel import lazy_gettext as _ +from marshmallow.validate import ValidationError + +from superset.commands.exceptions import ( + CommandException, + CommandInvalidError, + CreateFailedError, + DeleteFailedError, + ForbiddenError, + UpdateFailedError, +) +from superset.views.base import get_datasource_exist_error_msg + + +class DatabaseNotFoundValidationError(ValidationError): + """ + Marshmallow validation error for database does not exist + """ + + def __init__(self): + super().__init__(_("Database does not exist"), field_names=["database"]) + + +class DatabaseChangeValidationError(ValidationError): + """ + Marshmallow validation error database changes are not allowed on update + """ + + def __init__(self): + super().__init__(_("Database not allowed to change"), field_names=["database"]) + + +class DatasetExistsValidationError(ValidationError): + """ + Marshmallow validation error for dataset already exists + """ + + def __init__(self, table_name: str): + super().__init__( + get_datasource_exist_error_msg(table_name), field_names=["table_name"] + ) + + +class TableNotFoundValidationError(ValidationError): + """ + Marshmallow validation error when a table does not exist on the database + """ + + def __init__(self, table_name: str): + super().__init__( + _( + f"Table [{table_name}] could not be found, " + "please double check your " + "database connection, schema, and " + f"table name" + ), + field_names=["table_name"], + ) + + +class OwnersNotFoundValidationError(ValidationError): + def __init__(self): + super().__init__(_("Owners are invalid"), field_names=["owners"]) + + +class DatasetNotFoundError(CommandException): + message = "Dataset not found." + + +class DatasetInvalidError(CommandInvalidError): + message = _("Dataset parameters are invalid.") + + +class DatasetCreateFailedError(CreateFailedError): + message = _("Dataset could not be created.") + + +class DatasetUpdateFailedError(UpdateFailedError): + message = _("Dataset could not be updated.") + + +class DatasetDeleteFailedError(DeleteFailedError): + message = _("Dataset could not be deleted.") + + +class DatasetForbiddenError(ForbiddenError): + message = _("Changing this dataset is forbidden") diff --git a/superset/datasets/commands/update.py b/superset/datasets/commands/update.py new file mode 100644 index 000000000..b3deeab2e --- /dev/null +++ b/superset/datasets/commands/update.py @@ -0,0 +1,90 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import logging +from typing import Dict, List, Optional + +from flask_appbuilder.security.sqla.models import User +from marshmallow import ValidationError + +from superset.commands.base import BaseCommand +from superset.commands.exceptions import UpdateFailedError +from superset.connectors.sqla.models import SqlaTable +from superset.datasets.commands.base import populate_owners +from superset.datasets.commands.exceptions import ( + DatabaseChangeValidationError, + DatasetExistsValidationError, + DatasetForbiddenError, + DatasetInvalidError, + DatasetNotFoundError, + DatasetUpdateFailedError, +) +from superset.datasets.dao import DatasetDAO +from superset.exceptions import SupersetSecurityException +from superset.views.base import check_ownership + +logger = logging.getLogger(__name__) + + +class UpdateDatasetCommand(BaseCommand): + def __init__(self, user: User, model_id: int, data: Dict): + self._actor = user + self._model_id = model_id + self._properties = data.copy() + self._model: Optional[SqlaTable] = None + + def run(self): + self.validate() + try: + dataset = DatasetDAO.update(self._model, self._properties) + except UpdateFailedError as e: + logger.exception(e.exception) + raise DatasetUpdateFailedError() + return dataset + + def validate(self) -> None: + exceptions = list() + owner_ids: Optional[List[int]] = self._properties.get("owners") + # Validate/populate model exists + self._model = DatasetDAO.find_by_id(self._model_id) + if not self._model: + raise DatasetNotFoundError() + # Check ownership + try: + check_ownership(self._model) + except SupersetSecurityException: + raise DatasetForbiddenError() + + database_id = self._properties.get("database", None) + table_name = self._properties.get("table_name", None) + # Validate uniqueness + if not DatasetDAO.validate_update_uniqueness( + self._model.database_id, self._model_id, table_name + ): + exceptions.append(DatasetExistsValidationError(table_name)) + # Validate/Populate database not allowed to change + if database_id and database_id != self._model: + exceptions.append(DatabaseChangeValidationError()) + # Validate/Populate owner + try: + owners = populate_owners(self._actor, owner_ids) + self._properties["owners"] = owners + except ValidationError as e: + exceptions.append(e) + if exceptions: + exception = DatasetInvalidError() + exception.add_list(exceptions) + raise exception diff --git a/superset/datasets/dao.py b/superset/datasets/dao.py new file mode 100644 index 000000000..7e08ce8c0 --- /dev/null +++ b/superset/datasets/dao.py @@ -0,0 +1,125 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import logging +from typing import Dict, Optional + +from flask import current_app +from flask_appbuilder.models.sqla.interface import SQLAInterface +from sqlalchemy.exc import SQLAlchemyError + +from superset.commands.exceptions import ( + CreateFailedError, + DeleteFailedError, + UpdateFailedError, +) +from superset.connectors.sqla.models import SqlaTable +from superset.extensions import db +from superset.models.core import Database +from superset.views.base import DatasourceFilter + +logger = logging.getLogger(__name__) + + +class DatasetDAO: + @staticmethod + def get_owner_by_id(owner_id: int) -> Optional[object]: + return ( + db.session.query(current_app.appbuilder.sm.user_model) + .filter_by(id=owner_id) + .one_or_none() + ) + + @staticmethod + def get_database_by_id(database_id) -> Optional[Database]: + try: + return db.session.query(Database).filter_by(id=database_id).one_or_none() + except SQLAlchemyError as e: # pragma: no cover + logger.error(f"Could not get database by id: {e}") + return None + + @staticmethod + def validate_table_exists(database: Database, table_name: str, schema: str) -> bool: + try: + database.get_table(table_name, schema=schema) + return True + except SQLAlchemyError as e: # pragma: no cover + logger.error(f"Got an error {e} validating table: {table_name}") + return False + + @staticmethod + def validate_uniqueness(database_id: int, name: str) -> bool: + dataset_query = db.session.query(SqlaTable).filter( + SqlaTable.table_name == name, SqlaTable.database_id == database_id + ) + return not db.session.query(dataset_query.exists()).scalar() + + @staticmethod + def validate_update_uniqueness( + database_id: int, dataset_id: int, name: str + ) -> bool: + dataset_query = db.session.query(SqlaTable).filter( + SqlaTable.table_name == name, + SqlaTable.database_id == database_id, + SqlaTable.id != dataset_id, + ) + return not db.session.query(dataset_query.exists()).scalar() + + @staticmethod + def find_by_id(model_id: int) -> SqlaTable: + data_model = SQLAInterface(SqlaTable, db.session) + query = db.session.query(SqlaTable) + query = DatasourceFilter("id", data_model).apply(query, None) + return query.filter_by(id=model_id).one_or_none() + + @staticmethod + def create(properties: Dict, commit=True) -> Optional[SqlaTable]: + model = SqlaTable() + for key, value in properties.items(): + setattr(model, key, value) + try: + db.session.add(model) + if commit: + db.session.commit() + except SQLAlchemyError as e: # pragma: no cover + db.session.rollback() + raise CreateFailedError(exception=e) + return model + + @staticmethod + def update(model: SqlaTable, properties: Dict, commit=True) -> Optional[SqlaTable]: + for key, value in properties.items(): + setattr(model, key, value) + try: + db.session.merge(model) + if commit: + db.session.commit() + except SQLAlchemyError as e: # pragma: no cover + db.session.rollback() + raise UpdateFailedError(exception=e) + return model + + @staticmethod + def delete(model: SqlaTable, commit=True): + try: + db.session.delete(model) + if commit: + db.session.commit() + except SQLAlchemyError as e: # pragma: no cover + logger.error(f"Failed to delete dataset: {e}") + db.session.rollback() + raise DeleteFailedError(exception=e) + return model diff --git a/superset/datasets/schemas.py b/superset/datasets/schemas.py new file mode 100644 index 000000000..370550da6 --- /dev/null +++ b/superset/datasets/schemas.py @@ -0,0 +1,42 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from marshmallow import fields, Schema +from marshmallow.validate import Length + + +class DatasetPostSchema(Schema): + database = fields.Integer(required=True) + schema = fields.String(validate=Length(0, 250)) + table_name = fields.String(required=True, allow_none=False, validate=Length(1, 250)) + owners = fields.List(fields.Integer()) + + +class DatasetPutSchema(Schema): + table_name = fields.String(allow_none=True, validate=Length(1, 250)) + sql = fields.String(allow_none=True) + filter_select_enabled = fields.Boolean(allow_none=True) + fetch_values_predicate = fields.String(allow_none=True, validate=Length(0, 1000)) + schema = fields.String(allow_none=True, validate=Length(1, 255)) + description = fields.String(allow_none=True) + main_dttm_col = fields.String(allow_none=True) + offset = fields.Integer(allow_none=True) + default_endpoint = fields.String(allow_none=True) + cache_timeout = fields.Integer(allow_none=True) + is_sqllab_view = fields.Boolean(allow_none=True) + template_params = fields.String(allow_none=True) + owners = fields.List(fields.Integer()) diff --git a/superset/views/base_api.py b/superset/views/base_api.py index ea9286ee3..86a86a6ca 100644 --- a/superset/views/base_api.py +++ b/superset/views/base_api.py @@ -21,7 +21,7 @@ from typing import Dict, Tuple from flask import request from flask_appbuilder import ModelRestApi from flask_appbuilder.api import expose, protect, rison, safe -from flask_appbuilder.models.filters import Filters +from flask_appbuilder.models.filters import BaseFilter, Filters from sqlalchemy.exc import SQLAlchemyError from superset.exceptions import SupersetSecurityException @@ -90,7 +90,15 @@ class BaseSupersetModelRestApi(ModelRestApi): Declare the related field field for filtering:: filter_rel_fields_field = { - "": "", "") + "": "") + } + """ # pylint: disable=pointless-string-statement + filter_rel_fields: Dict[str, BaseFilter] = {} + """ + Declare the related field base filter:: + + filter_rel_fields_field = { + "": "") } """ # pylint: disable=pointless-string-statement @@ -117,6 +125,9 @@ class BaseSupersetModelRestApi(ModelRestApi): def _get_related_filter(self, datamodel, column_name: str, value: str) -> Filters: filter_field = self.filter_rel_fields_field.get(column_name) filters = datamodel.get_filters([filter_field]) + base_filters = self.filter_rel_fields.get(column_name) + if base_filters: + filters = filters.add_filter_list(base_filters) if value: filters.rest_add_filters( [{"opr": "sw", "col": filter_field, "value": value}] diff --git a/tests/dataset_api_tests.py b/tests/dataset_api_tests.py new file mode 100644 index 000000000..3f765b715 --- /dev/null +++ b/tests/dataset_api_tests.py @@ -0,0 +1,450 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Unit tests for Superset""" +import json +from typing import List +from unittest.mock import patch + +import prison + +from superset import db, security_manager +from superset.commands.exceptions import ( + CreateFailedError, + DeleteFailedError, + UpdateFailedError, +) +from superset.connectors.sqla.models import SqlaTable +from superset.models.core import Database +from superset.utils.core import get_example_database + +from .base_tests import SupersetTestCase + + +class DatasetApiTests(SupersetTestCase): + @staticmethod + def insert_dataset( + table_name: str, schema: str, owners: List[int], database: Database + ) -> SqlaTable: + obj_owners = list() + for owner in owners: + user = db.session.query(security_manager.user_model).get(owner) + obj_owners.append(user) + table = SqlaTable( + table_name=table_name, schema=schema, owners=obj_owners, database=database + ) + db.session.add(table) + db.session.commit() + return table + + def test_get_dataset_list(self): + """ + Dataset API: Test get dataset list + """ + example_db = get_example_database() + self.login(username="admin") + arguments = { + "filters": [ + {"col": "database", "opr": "rel_o_m", "value": f"{example_db.id}"}, + {"col": "table_name", "opr": "eq", "value": f"birth_names"}, + ] + } + uri = f"api/v1/dataset/?q={prison.dumps(arguments)}" + rv = self.client.get(uri) + self.assertEqual(rv.status_code, 200) + response = json.loads(rv.data.decode("utf-8")) + self.assertEqual(response["count"], 1) + expected_columns = [ + "changed_by", + "changed_on", + "database_name", + "schema", + "table_name", + ] + self.assertEqual(sorted(list(response["result"][0].keys())), expected_columns) + + def test_get_dataset_list_gamma(self): + """ + Dataset API: Test get dataset list gamma + """ + example_db = get_example_database() + self.login(username="gamma") + uri = "api/v1/dataset/" + rv = self.client.get(uri) + self.assertEqual(rv.status_code, 200) + response = json.loads(rv.data.decode("utf-8")) + self.assertEqual(response["result"], []) + + def test_get_dataset_related_database_gamma(self): + """ + Dataset API: Test get dataset related databases gamma + """ + example_db = get_example_database() + self.login(username="gamma") + uri = "api/v1/dataset/related/database" + rv = self.client.get(uri) + self.assertEqual(rv.status_code, 200) + response = json.loads(rv.data.decode("utf-8")) + self.assertEqual(response["count"], 0) + self.assertEqual(response["result"], []) + + def test_get_dataset_item(self): + """ + Dataset API: Test get dataset item + """ + example_db = get_example_database() + table = ( + db.session.query(SqlaTable) + .filter_by(database=example_db, table_name="birth_names") + .one() + ) + self.login(username="admin") + uri = f"api/v1/dataset/{table.id}" + rv = self.client.get(uri) + self.assertEqual(rv.status_code, 200) + response = json.loads(rv.data.decode("utf-8")) + expected_result = { + "cache_timeout": None, + "database": {"database_name": "examples", "id": 1}, + "default_endpoint": None, + "description": None, + "fetch_values_predicate": None, + "filter_select_enabled": True, + "is_sqllab_view": False, + "main_dttm_col": "ds", + "offset": 0, + "owners": [], + "schema": None, + "sql": None, + "table_name": "birth_names", + "template_params": None, + } + self.assertEqual(response["result"], expected_result) + + def test_get_dataset_info(self): + """ + Dataset API: Test get dataset info + """ + self.login(username="admin") + uri = "api/v1/dataset/_info" + rv = self.client.get(uri) + self.assertEqual(rv.status_code, 200) + + def test_create_dataset_item(self): + """ + Dataset API: Test create dataset item + """ + example_db = get_example_database() + self.login(username="admin") + table_data = { + "database": example_db.id, + "schema": "", + "table_name": "ab_permission", + } + uri = "api/v1/dataset/" + rv = self.client.post(uri, json=table_data) + self.assertEqual(rv.status_code, 201) + data = json.loads(rv.data.decode("utf-8")) + model = db.session.query(SqlaTable).get(data.get("id")) + self.assertEqual(model.table_name, table_data["table_name"]) + self.assertEqual(model.database_id, table_data["database"]) + db.session.delete(model) + db.session.commit() + + def test_create_dataset_item_gamma(self): + """ + Dataset API: Test create dataset item gamma + """ + self.login(username="gamma") + example_db = get_example_database() + table_data = { + "database": example_db.id, + "schema": "", + "table_name": "ab_permission", + } + uri = "api/v1/dataset/" + rv = self.client.post(uri, json=table_data) + self.assertEqual(rv.status_code, 401) + + def test_create_dataset_item_owner(self): + """ + Dataset API: Test create item owner + """ + example_db = get_example_database() + self.login(username="alpha") + admin = self.get_user("admin") + alpha = self.get_user("alpha") + + table_data = { + "database": example_db.id, + "schema": "", + "table_name": "ab_permission", + "owners": [admin.id], + } + uri = "api/v1/dataset/" + rv = self.client.post(uri, json=table_data) + self.assertEqual(rv.status_code, 201) + data = json.loads(rv.data.decode("utf-8")) + model = db.session.query(SqlaTable).get(data.get("id")) + self.assertIn(admin, model.owners) + self.assertIn(alpha, model.owners) + db.session.delete(model) + db.session.commit() + + def test_create_dataset_item_owners_invalid(self): + """ + Dataset API: Test create dataset item owner invalid + """ + admin = self.get_user("admin") + example_db = get_example_database() + self.login(username="admin") + table_data = { + "database": example_db.id, + "schema": "", + "table_name": "ab_permission", + "owners": [admin.id, 1000], + } + uri = f"api/v1/dataset/" + rv = self.client.post(uri, json=table_data) + self.assertEqual(rv.status_code, 422) + data = json.loads(rv.data.decode("utf-8")) + expected_result = {"message": {"owners": ["Owners are invalid"]}} + self.assertEqual(data, expected_result) + + def test_create_dataset_validate_uniqueness(self): + """ + Dataset API: Test create dataset validate table uniqueness + """ + example_db = get_example_database() + self.login(username="admin") + table_data = { + "database": example_db.id, + "schema": "", + "table_name": "birth_names", + } + uri = "api/v1/dataset/" + rv = self.client.post(uri, json=table_data) + self.assertEqual(rv.status_code, 422) + data = json.loads(rv.data.decode("utf-8")) + self.assertEqual( + data, {"message": {"table_name": ["Datasource birth_names already exists"]}} + ) + + def test_create_dataset_validate_database(self): + """ + Dataset API: Test create dataset validate database exists + """ + self.login(username="admin") + table_data = {"database": 1000, "schema": "", "table_name": "birth_names"} + uri = "api/v1/dataset/" + rv = self.client.post(uri, json=table_data) + self.assertEqual(rv.status_code, 422) + data = json.loads(rv.data.decode("utf-8")) + self.assertEqual(data, {"message": {"database": ["Database does not exist"]}}) + + def test_create_dataset_validate_tables_exists(self): + """ + Dataset API: Test create dataset validate table exists + """ + example_db = get_example_database() + self.login(username="admin") + table_data = { + "database": example_db.id, + "schema": "", + "table_name": "does_not_exist", + } + uri = "api/v1/dataset/" + rv = self.client.post(uri, json=table_data) + self.assertEqual(rv.status_code, 422) + + @patch("superset.datasets.dao.DatasetDAO.create") + def test_create_dataset_sqlalchemy_error(self, mock_dao_create): + """ + Dataset API: Test create dataset sqlalchemy error + """ + mock_dao_create.side_effect = CreateFailedError() + self.login(username="admin") + example_db = get_example_database() + dataset_data = { + "database": example_db.id, + "schema": "", + "table_name": "ab_permission", + } + uri = "api/v1/dataset/" + rv = self.client.post(uri, json=dataset_data) + data = json.loads(rv.data.decode("utf-8")) + self.assertEqual(rv.status_code, 422) + self.assertEqual(data, {"message": "Dataset could not be created."}) + + def test_update_dataset_item(self): + """ + Dataset API: Test update dataset item + """ + table = self.insert_dataset("ab_permission", "", [], get_example_database()) + self.login(username="admin") + table_data = {"description": "changed_description"} + uri = f"api/v1/dataset/{table.id}" + rv = self.client.put(uri, json=table_data) + self.assertEqual(rv.status_code, 200) + model = db.session.query(SqlaTable).get(table.id) + self.assertEqual(model.description, table_data["description"]) + db.session.delete(table) + db.session.commit() + + def test_update_dataset_item_gamma(self): + """ + Dataset API: Test update dataset item gamma + """ + table = self.insert_dataset("ab_permission", "", [], get_example_database()) + self.login(username="gamma") + table_data = {"description": "changed_description"} + uri = f"api/v1/dataset/{table.id}" + rv = self.client.put(uri, json=table_data) + self.assertEqual(rv.status_code, 401) + db.session.delete(table) + db.session.commit() + + def test_update_dataset_item_not_owned(self): + """ + Dataset API: Test update dataset item not owned + """ + admin = self.get_user("admin") + table = self.insert_dataset( + "ab_permission", "", [admin.id], get_example_database() + ) + self.login(username="alpha") + table_data = {"description": "changed_description"} + uri = f"api/v1/dataset/{table.id}" + rv = self.client.put(uri, json=table_data) + self.assertEqual(rv.status_code, 403) + db.session.delete(table) + db.session.commit() + + def test_update_dataset_item_owners_invalid(self): + """ + Dataset API: Test update dataset item owner invalid + """ + admin = self.get_user("admin") + table = self.insert_dataset( + "ab_permission", "", [admin.id], get_example_database() + ) + self.login(username="admin") + table_data = {"description": "changed_description", "owners": [1000]} + uri = f"api/v1/dataset/{table.id}" + rv = self.client.put(uri, json=table_data) + self.assertEqual(rv.status_code, 422) + db.session.delete(table) + db.session.commit() + + def test_update_dataset_item_uniqueness(self): + """ + Dataset API: Test update dataset uniqueness + """ + admin = self.get_user("admin") + table = self.insert_dataset( + "ab_permission", "", [admin.id], get_example_database() + ) + self.login(username="admin") + table_data = {"table_name": "birth_names"} + uri = f"api/v1/dataset/{table.id}" + rv = self.client.put(uri, json=table_data) + data = json.loads(rv.data.decode("utf-8")) + self.assertEqual(rv.status_code, 422) + expected_response = { + "message": {"table_name": ["Datasource birth_names already exists"]} + } + self.assertEqual(data, expected_response) + db.session.delete(table) + db.session.commit() + + @patch("superset.datasets.dao.DatasetDAO.update") + def test_update_dataset_sqlalchemy_error(self, mock_dao_update): + """ + Dataset API: Test update dataset sqlalchemy error + """ + mock_dao_update.side_effect = UpdateFailedError() + + table = self.insert_dataset("ab_permission", "", [], get_example_database()) + self.login(username="admin") + table_data = {"description": "changed_description"} + uri = f"api/v1/dataset/{table.id}" + rv = self.client.put(uri, json=table_data) + data = json.loads(rv.data.decode("utf-8")) + self.assertEqual(rv.status_code, 422) + self.assertEqual(data, {"message": "Dataset could not be updated."}) + + def test_delete_dataset_item(self): + """ + Dataset API: Test delete dataset item + """ + admin = self.get_user("admin") + table = self.insert_dataset( + "ab_permission", "", [admin.id], get_example_database() + ) + self.login(username="admin") + uri = f"api/v1/dataset/{table.id}" + rv = self.client.delete(uri) + self.assertEqual(rv.status_code, 200) + + def test_delete_item_dataset_not_owned(self): + """ + Dataset API: Test delete item not owned + """ + admin = self.get_user("admin") + table = self.insert_dataset( + "ab_permission", "", [admin.id], get_example_database() + ) + self.login(username="alpha") + uri = f"api/v1/dataset/{table.id}" + rv = self.client.delete(uri) + self.assertEqual(rv.status_code, 403) + db.session.delete(table) + db.session.commit() + + def test_delete_dataset_item_not_authorized(self): + """ + Dataset API: Test delete item not authorized + """ + admin = self.get_user("admin") + table = self.insert_dataset( + "ab_permission", "", [admin.id], get_example_database() + ) + self.login(username="gamma") + uri = f"api/v1/dataset/{table.id}" + rv = self.client.delete(uri) + self.assertEqual(rv.status_code, 401) + db.session.delete(table) + db.session.commit() + + @patch("superset.datasets.dao.DatasetDAO.delete") + def test_delete_dataset_sqlalchemy_error(self, mock_dao_delete): + """ + Dataset API: Test delete dataset sqlalchemy error + """ + mock_dao_delete.side_effect = DeleteFailedError() + + admin = self.get_user("admin") + table = self.insert_dataset( + "ab_permission", "", [admin.id], get_example_database() + ) + self.login(username="admin") + uri = f"api/v1/dataset/{table.id}" + rv = self.client.delete(uri) + data = json.loads(rv.data.decode("utf-8")) + self.assertEqual(rv.status_code, 422) + self.assertEqual(data, {"message": "Dataset could not be deleted."}) + db.session.delete(table) + db.session.commit()