diff --git a/superset/cli.py b/superset/cli.py index 2c32a2ba4..adff31f71 100755 --- a/superset/cli.py +++ b/superset/cli.py @@ -301,7 +301,7 @@ def export_dashboards(dashboard_file: str, print_stdout: bool) -> None: ) def import_datasources(path: str, sync: str, recursive: bool) -> None: """Import datasources from YAML""" - from superset.datasets.commands.importers.v0 import ImportDatasetsCommand + from superset.datasets.commands.importers.dispatcher import ImportDatasetsCommand sync_array = sync.split(",") sync_columns = "columns" in sync_array diff --git a/superset/constants.py b/superset/constants.py index ea14a38a0..65ef2b13a 100644 --- a/superset/constants.py +++ b/superset/constants.py @@ -50,6 +50,7 @@ class RouteMethod: # pylint: disable=too-few-public-methods # RestModelView specific EXPORT = "export" + IMPORT = "import_" GET = "get" GET_LIST = "get_list" POST = "post" diff --git a/superset/databases/api.py b/superset/databases/api.py index dc54c6b39..4d8a0d413 100644 --- a/superset/databases/api.py +++ b/superset/databases/api.py @@ -35,6 +35,7 @@ from sqlalchemy.exc import ( ) from superset import event_logger +from superset.commands.exceptions import CommandInvalidError from superset.constants import RouteMethod from superset.databases.commands.create import CreateDatabaseCommand from superset.databases.commands.delete import DeleteDatabaseCommand @@ -49,6 +50,7 @@ from superset.databases.commands.exceptions import ( DatabaseUpdateFailedError, ) from superset.databases.commands.export import ExportDatabasesCommand +from superset.databases.commands.importers.dispatcher import ImportDatabasesCommand from superset.databases.commands.test_connection import TestConnectionDatabaseCommand from superset.databases.commands.update import UpdateDatabaseCommand from superset.databases.dao import DatabaseDAO @@ -80,6 +82,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi): include_route_methods = RouteMethod.REST_MODEL_VIEW_CRUD_SET | { RouteMethod.EXPORT, + RouteMethod.IMPORT, "table_metadata", "select_star", "schemas", @@ -722,3 +725,56 @@ class DatabaseRestApi(BaseSupersetModelRestApi): as_attachment=True, attachment_filename=filename, ) + + @expose("/import/", methods=["POST"]) + @protect() + @safe + @statsd_metrics + def import_(self) -> Response: + """Import database(s) with associated datasets + --- + post: + requestBody: + content: + application/zip: + schema: + type: string + format: binary + responses: + 200: + description: Database import result + content: + application/json: + schema: + type: object + properties: + message: + type: string + 400: + $ref: '#/components/responses/400' + 401: + $ref: '#/components/responses/401' + 422: + $ref: '#/components/responses/422' + 500: + $ref: '#/components/responses/500' + """ + upload = request.files.get("file") + if not upload: + return self.response_400() + with ZipFile(upload) as bundle: + contents = { + file_name: bundle.read(file_name).decode() + for file_name in bundle.namelist() + } + + command = ImportDatabasesCommand(contents) + try: + command.run() + return self.response(200, message="OK") + except CommandInvalidError as exc: + logger.warning("Import database failed") + return self.response_422(message=exc.normalized_messages()) + except Exception as exc: # pylint: disable=broad-except + logger.exception("Import database failed") + return self.response_500(message=str(exc)) diff --git a/superset/databases/commands/importers/dispatcher.py b/superset/databases/commands/importers/dispatcher.py new file mode 100644 index 000000000..a29f5ca10 --- /dev/null +++ b/superset/databases/commands/importers/dispatcher.py @@ -0,0 +1,68 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import logging +from typing import Any, Dict + +from marshmallow.exceptions import ValidationError + +from superset.commands.base import BaseCommand +from superset.commands.exceptions import CommandInvalidError +from superset.commands.importers.exceptions import IncorrectVersionError +from superset.databases.commands.importers import v1 + +logger = logging.getLogger(__name__) + +command_versions = [v1.ImportDatabasesCommand] + + +class ImportDatabasesCommand(BaseCommand): + """ + Import databases. + + This command dispatches the import to different versions of the command + until it finds one that matches. + """ + + # pylint: disable=unused-argument + def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any): + self.contents = contents + + def run(self) -> None: + # iterate over all commands until we find a version that can + # handle the contents + for version in command_versions: + command = version(self.contents) + try: + command.run() + return + except IncorrectVersionError: + # file is not handled by this command, skip + pass + except (CommandInvalidError, ValidationError) as exc: + # found right version, but file is invalid + logger.info("Command failed validation") + raise exc + except Exception as exc: + # validation succeeded but something went wrong + logger.exception("Error running import command") + raise exc + + raise CommandInvalidError("Could not find a valid command to import file") + + def validate(self) -> None: + pass diff --git a/superset/datasets/api.py b/superset/datasets/api.py index fbc01a137..1decc707a 100644 --- a/superset/datasets/api.py +++ b/superset/datasets/api.py @@ -28,6 +28,7 @@ from flask_babel import ngettext from marshmallow import ValidationError from superset import is_feature_enabled +from superset.commands.exceptions import CommandInvalidError from superset.connectors.sqla.models import SqlaTable from superset.constants import RouteMethod from superset.databases.filters import DatabaseFilter @@ -45,6 +46,7 @@ from superset.datasets.commands.exceptions import ( DatasetUpdateFailedError, ) from superset.datasets.commands.export import ExportDatasetsCommand +from superset.datasets.commands.importers.dispatcher import ImportDatasetsCommand from superset.datasets.commands.refresh import RefreshDatasetCommand from superset.datasets.commands.update import UpdateDatasetCommand from superset.datasets.dao import DatasetDAO @@ -76,6 +78,7 @@ class DatasetRestApi(BaseSupersetModelRestApi): class_permission_name = "TableModelView" include_route_methods = RouteMethod.REST_MODEL_VIEW_CRUD_SET | { RouteMethod.EXPORT, + RouteMethod.IMPORT, RouteMethod.RELATED, RouteMethod.DISTINCT, "bulk_delete", @@ -589,3 +592,56 @@ class DatasetRestApi(BaseSupersetModelRestApi): return self.response_403() except DatasetBulkDeleteFailedError as ex: return self.response_422(message=str(ex)) + + @expose("/import/", methods=["POST"]) + @protect() + @safe + @statsd_metrics + def import_(self) -> Response: + """Import dataset (s) with associated databases + --- + post: + requestBody: + content: + application/zip: + schema: + type: string + format: binary + responses: + 200: + description: Dataset import result + content: + application/json: + schema: + type: object + properties: + message: + type: string + 400: + $ref: '#/components/responses/400' + 401: + $ref: '#/components/responses/401' + 422: + $ref: '#/components/responses/422' + 500: + $ref: '#/components/responses/500' + """ + upload = request.files.get("file") + if not upload: + return self.response_400() + with ZipFile(upload) as bundle: + contents = { + file_name: bundle.read(file_name).decode() + for file_name in bundle.namelist() + } + + command = ImportDatasetsCommand(contents) + try: + command.run() + return self.response(200, message="OK") + except CommandInvalidError as exc: + logger.warning("Import dataset failed") + return self.response_422(message=exc.normalized_messages()) + except Exception as exc: # pylint: disable=broad-except + logger.exception("Import dataset failed") + return self.response_500(message=str(exc)) diff --git a/superset/datasets/commands/importers/dispatcher.py b/superset/datasets/commands/importers/dispatcher.py new file mode 100644 index 000000000..99a4c269e --- /dev/null +++ b/superset/datasets/commands/importers/dispatcher.py @@ -0,0 +1,73 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import logging +from typing import Any, Dict + +from marshmallow.exceptions import ValidationError + +from superset.commands.base import BaseCommand +from superset.commands.exceptions import CommandInvalidError +from superset.commands.importers.exceptions import IncorrectVersionError +from superset.datasets.commands.importers import v0, v1 + +logger = logging.getLogger(__name__) + +# list of different import formats supported; v0 should be last because +# the files are not versioned +command_versions = [ + v1.ImportDatasetsCommand, + v0.ImportDatasetsCommand, +] + + +class ImportDatasetsCommand(BaseCommand): + """ + Import datasets. + + This command dispatches the import to different versions of the command + until it finds one that matches. + """ + + # pylint: disable=unused-argument + def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any): + self.contents = contents + + def run(self) -> None: + # iterate over all commands until we find a version that can + # handle the contents + for version in command_versions: + command = version(self.contents) + try: + command.run() + return + except IncorrectVersionError: + # file is not handled by command, skip + pass + except (CommandInvalidError, ValidationError) as exc: + # found right version, but file is invalid + logger.info("Command failed validation") + raise exc + except Exception as exc: + # validation succeeded but something went wrong + logger.exception("Error running import command") + raise exc + + raise CommandInvalidError("Could not find a valid command to import file") + + def validate(self) -> None: + pass diff --git a/superset/datasets/commands/importers/v0.py b/superset/datasets/commands/importers/v0.py index b8c0ab65a..d45c58c58 100644 --- a/superset/datasets/commands/importers/v0.py +++ b/superset/datasets/commands/importers/v0.py @@ -25,6 +25,7 @@ from sqlalchemy.orm.session import make_transient from superset import db from superset.commands.base import BaseCommand +from superset.commands.importers.exceptions import IncorrectVersionError from superset.connectors.base.models import BaseColumn, BaseDatasource, BaseMetric from superset.connectors.druid.models import ( DruidCluster, @@ -289,6 +290,7 @@ class ImportDatasetsCommand(BaseCommand): sync_metrics: bool = False, ): self.contents = contents + self._configs: Dict[str, Any] = {} self.sync = [] if sync_columns: @@ -299,15 +301,21 @@ class ImportDatasetsCommand(BaseCommand): def run(self) -> None: self.validate() - for file_name, content in self.contents.items(): + for file_name, config in self._configs.items(): logger.info("Importing dataset from file %s", file_name) - import_from_dict(db.session, yaml.safe_load(content), sync=self.sync) + import_from_dict(db.session, config, sync=self.sync) def validate(self) -> None: # ensure all files are YAML - for content in self.contents.values(): + for file_name, content in self.contents.items(): try: - yaml.safe_load(content) + config = yaml.safe_load(content) except yaml.parser.ParserError: logger.exception("Invalid YAML file") - raise + raise IncorrectVersionError(f"{file_name} is not a valid YAML file") + + # check for keys + if DATABASES_KEY not in config and DRUID_CLUSTERS_KEY not in config: + raise IncorrectVersionError(f"{file_name} has no valid keys") + + self._configs[file_name] = config diff --git a/superset/views/base_api.py b/superset/views/base_api.py index a83fcc671..1495a79b1 100644 --- a/superset/views/base_api.py +++ b/superset/views/base_api.py @@ -128,6 +128,7 @@ class BaseSupersetModelRestApi(ModelRestApi): "delete": "delete", "distinct": "list", "export": "mulexport", + "import_": "add", "get": "show", "get_list": "list", "info": "list", diff --git a/tests/databases/api_tests.py b/tests/databases/api_tests.py index 2b8504daf..e38b64a7b 100644 --- a/tests/databases/api_tests.py +++ b/tests/databases/api_tests.py @@ -15,13 +15,15 @@ # specific language governing permissions and limitations # under the License. # isort:skip_file +# pylint: disable=invalid-name, no-self-use, too-many-public-methods, too-many-arguments """Unit tests for Superset""" import json from io import BytesIO -from zipfile import is_zipfile +from zipfile import is_zipfile, ZipFile import prison import pytest +import yaml from sqlalchemy.sql import func @@ -31,6 +33,13 @@ from superset.models.core import Database from superset.utils.core import get_example_database, get_main_database from tests.base_tests import SupersetTestCase from tests.fixtures.certificates import ssl_certificate +from tests.fixtures.importexport import ( + database_config, + dataset_config, + database_metadata_config, + dataset_metadata_config, +) + from tests.fixtures.unicode_dashboard import load_unicode_dashboard_with_position from tests.test_app import app @@ -817,3 +826,71 @@ class TestDatabaseApi(SupersetTestCase): uri = f"api/v1/database/export/?q={prison.dumps(argument)}" rv = self.get_assert_metric(uri, "export") assert rv.status_code == 404 + + def test_import_database(self): + """ + Database API: Test import database + """ + self.login(username="admin") + uri = "api/v1/database/import/" + + buf = BytesIO() + with ZipFile(buf, "w") as bundle: + with bundle.open("metadata.yaml", "w") as fp: + fp.write(yaml.safe_dump(database_metadata_config).encode()) + with bundle.open("databases/imported_database.yaml", "w") as fp: + fp.write(yaml.safe_dump(database_config).encode()) + with bundle.open("datasets/import_dataset.yaml", "w") as fp: + fp.write(yaml.safe_dump(dataset_config).encode()) + buf.seek(0) + + form_data = { + "file": (buf, "database_export.zip"), + } + rv = self.client.post(uri, data=form_data, content_type="multipart/form-data") + response = json.loads(rv.data.decode("utf-8")) + + assert rv.status_code == 200 + assert response == {"message": "OK"} + + database = ( + db.session.query(Database).filter_by(uuid=database_config["uuid"]).one() + ) + assert database.database_name == "imported_database" + + assert len(database.tables) == 1 + dataset = database.tables[0] + assert dataset.table_name == "imported_dataset" + assert str(dataset.uuid) == dataset_config["uuid"] + + db.session.delete(dataset) + db.session.delete(database) + db.session.commit() + + def test_import_database_invalid(self): + """ + Database API: Test import invalid database + """ + self.login(username="admin") + uri = "api/v1/database/import/" + + buf = BytesIO() + with ZipFile(buf, "w") as bundle: + with bundle.open("metadata.yaml", "w") as fp: + fp.write(yaml.safe_dump(dataset_metadata_config).encode()) + with bundle.open("databases/imported_database.yaml", "w") as fp: + fp.write(yaml.safe_dump(database_config).encode()) + with bundle.open("datasets/import_dataset.yaml", "w") as fp: + fp.write(yaml.safe_dump(dataset_config).encode()) + buf.seek(0) + + form_data = { + "file": (buf, "database_export.zip"), + } + rv = self.client.post(uri, data=form_data, content_type="multipart/form-data") + response = json.loads(rv.data.decode("utf-8")) + + assert rv.status_code == 422 + assert response == { + "message": {"metadata.yaml": {"type": ["Must be equal to Database."]}} + } diff --git a/tests/databases/commands_tests.py b/tests/databases/commands_tests.py index a88283f23..afaf50f95 100644 --- a/tests/databases/commands_tests.py +++ b/tests/databases/commands_tests.py @@ -432,7 +432,7 @@ class TestExportDatabasesCommand(SupersetTestCase): command.run() assert str(excinfo.value) == "Error importing database" assert excinfo.value.normalized_messages() == { - "metadata.yaml": {"type": ["Must be equal to Database."],} + "metadata.yaml": {"type": ["Must be equal to Database."]} } # must also validate datasets diff --git a/tests/datasets/api_tests.py b/tests/datasets/api_tests.py index 757ba77f5..59854c496 100644 --- a/tests/datasets/api_tests.py +++ b/tests/datasets/api_tests.py @@ -14,12 +14,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# pylint: disable=too-many-public-methods, invalid-name """Unit tests for Superset""" import json from io import BytesIO from typing import List, Optional from unittest.mock import patch -from zipfile import is_zipfile +from zipfile import is_zipfile, ZipFile import prison import pytest @@ -38,6 +39,12 @@ from superset.utils.core import backend, get_example_database, get_main_database from superset.utils.dict_import_export import export_to_dict from tests.base_tests import SupersetTestCase from tests.conftest import CTAS_SCHEMA_NAME +from tests.fixtures.importexport import ( + database_config, + database_metadata_config, + dataset_config, + dataset_metadata_config, +) class TestDatasetApi(SupersetTestCase): @@ -139,7 +146,7 @@ class TestDatasetApi(SupersetTestCase): arguments = { "filters": [ {"col": "database", "opr": "rel_o_m", "value": f"{example_db.id}"}, - {"col": "table_name", "opr": "eq", "value": f"birth_names"}, + {"col": "table_name", "opr": "eq", "value": "birth_names"}, ] } uri = f"api/v1/dataset/?q={prison.dumps(arguments)}" @@ -170,7 +177,6 @@ class TestDatasetApi(SupersetTestCase): """ Dataset API: Test get dataset list gamma """ - example_db = get_example_database() self.login(username="gamma") uri = "api/v1/dataset/" rv = self.get_assert_metric(uri, "get_list") @@ -423,7 +429,7 @@ class TestDatasetApi(SupersetTestCase): "table_name": "ab_permission", "owners": [admin.id, 1000], } - uri = f"api/v1/dataset/" + uri = "api/v1/dataset/" rv = self.post_assert_metric(uri, table_data, "post") assert rv.status_code == 422 data = json.loads(rv.data.decode("utf-8")) @@ -1169,3 +1175,95 @@ class TestDatasetApi(SupersetTestCase): data = json.loads(rv.data.decode("utf-8")) for table_name in self.fixture_tables_names: assert table_name in [ds["table_name"] for ds in data["result"]] + + def test_import_dataset(self): + """ + Dataset API: Test import dataset + """ + self.login(username="admin") + uri = "api/v1/dataset/import/" + + buf = BytesIO() + with ZipFile(buf, "w") as bundle: + with bundle.open("metadata.yaml", "w") as fp: + fp.write(yaml.safe_dump(dataset_metadata_config).encode()) + with bundle.open("databases/imported_database.yaml", "w") as fp: + fp.write(yaml.safe_dump(database_config).encode()) + with bundle.open("datasets/import_dataset.yaml", "w") as fp: + fp.write(yaml.safe_dump(dataset_config).encode()) + buf.seek(0) + + form_data = { + "file": (buf, "dataset_export.zip"), + } + rv = self.client.post(uri, data=form_data, content_type="multipart/form-data") + response = json.loads(rv.data.decode("utf-8")) + + assert rv.status_code == 200 + assert response == {"message": "OK"} + + database = ( + db.session.query(Database).filter_by(uuid=database_config["uuid"]).one() + ) + assert database.database_name == "imported_database" + + assert len(database.tables) == 1 + dataset = database.tables[0] + assert dataset.table_name == "imported_dataset" + assert str(dataset.uuid) == dataset_config["uuid"] + + db.session.delete(dataset) + db.session.delete(database) + db.session.commit() + + def test_import_dataset_invalid(self): + """ + Dataset API: Test import invalid dataset + """ + self.login(username="admin") + uri = "api/v1/dataset/import/" + + buf = BytesIO() + with ZipFile(buf, "w") as bundle: + with bundle.open("metadata.yaml", "w") as fp: + fp.write(yaml.safe_dump(database_metadata_config).encode()) + with bundle.open("databases/imported_database.yaml", "w") as fp: + fp.write(yaml.safe_dump(database_config).encode()) + with bundle.open("datasets/import_dataset.yaml", "w") as fp: + fp.write(yaml.safe_dump(dataset_config).encode()) + buf.seek(0) + + form_data = { + "file": (buf, "dataset_export.zip"), + } + rv = self.client.post(uri, data=form_data, content_type="multipart/form-data") + response = json.loads(rv.data.decode("utf-8")) + + assert rv.status_code == 422 + assert response == { + "message": {"metadata.yaml": {"type": ["Must be equal to SqlaTable."]}} + } + + def test_import_dataset_invalid_v0_validation(self): + """ + Dataset API: Test import invalid dataset + """ + self.login(username="admin") + uri = "api/v1/dataset/import/" + + buf = BytesIO() + with ZipFile(buf, "w") as bundle: + with bundle.open("databases/imported_database.yaml", "w") as fp: + fp.write(yaml.safe_dump(database_config).encode()) + with bundle.open("datasets/import_dataset.yaml", "w") as fp: + fp.write(yaml.safe_dump(dataset_config).encode()) + buf.seek(0) + + form_data = { + "file": (buf, "dataset_export.zip"), + } + rv = self.client.post(uri, data=form_data, content_type="multipart/form-data") + response = json.loads(rv.data.decode("utf-8")) + + assert rv.status_code == 422 + assert response == {"message": "Could not process entity"}