From 7bc353f8a8861ac45b3a64fc6463ed150f634831 Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Mon, 16 Nov 2020 17:11:20 -0800 Subject: [PATCH] feat: new import commands for dataset and databases (#11670) * feat: commands for importing databases and datasets * Refactor code --- superset/cli.py | 5 +- superset/commands/importers/exceptions.py | 23 ++ superset/commands/importers/v1/__init__.py | 16 ++ superset/commands/importers/v1/utils.py | 67 ++++++ .../databases/commands/importers/__init__.py | 16 ++ .../commands/importers/v1/__init__.py | 116 ++++++++++ .../databases/commands/importers/v1/utils.py | 42 ++++ superset/databases/schemas.py | 21 ++ superset/datasets/commands/importers/v0.py | 14 +- .../commands/importers/v1/__init__.py | 121 ++++++++++ .../datasets/commands/importers/v1/utils.py | 42 ++++ superset/datasets/schemas.py | 45 ++++ superset/models/helpers.py | 2 +- tests/databases/commands_tests.py | 209 +++++++++++++++++- tests/datasets/commands_tests.py | 161 +++++++++++++- tests/fixtures/importexport.py | 90 ++++++++ 16 files changed, 983 insertions(+), 7 deletions(-) create mode 100644 superset/commands/importers/exceptions.py create mode 100644 superset/commands/importers/v1/__init__.py create mode 100644 superset/commands/importers/v1/utils.py create mode 100644 superset/databases/commands/importers/__init__.py create mode 100644 superset/databases/commands/importers/v1/__init__.py create mode 100644 superset/databases/commands/importers/v1/utils.py create mode 100644 superset/datasets/commands/importers/v1/__init__.py create mode 100644 superset/datasets/commands/importers/v1/utils.py create mode 100644 tests/fixtures/importexport.py diff --git a/superset/cli.py b/superset/cli.py index 5130dbfff..2c32a2ba4 100755 --- a/superset/cli.py +++ b/superset/cli.py @@ -304,6 +304,9 @@ def import_datasources(path: str, sync: str, recursive: bool) -> None: from superset.datasets.commands.importers.v0 import ImportDatasetsCommand sync_array = 
sync.split(",") + sync_columns = "columns" in sync_array + sync_metrics = "metrics" in sync_array + path_object = Path(path) files: List[Path] = [] if path_object.is_file(): @@ -316,7 +319,7 @@ def import_datasources(path: str, sync: str, recursive: bool) -> None: files.extend(path_object.rglob("*.yml")) contents = {path.name: open(path).read() for path in files} try: - ImportDatasetsCommand(contents, sync_array).run() + ImportDatasetsCommand(contents, sync_columns, sync_metrics).run() except Exception: # pylint: disable=broad-except logger.exception("Error when importing dataset") diff --git a/superset/commands/importers/exceptions.py b/superset/commands/importers/exceptions.py new file mode 100644 index 000000000..e79cc1c5d --- /dev/null +++ b/superset/commands/importers/exceptions.py @@ -0,0 +1,23 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from superset.commands.exceptions import CommandException + + +class IncorrectVersionError(CommandException): + status = 422 + message = "Import has incorrect version" diff --git a/superset/commands/importers/v1/__init__.py b/superset/commands/importers/v1/__init__.py new file mode 100644 index 000000000..13a83393a --- /dev/null +++ b/superset/commands/importers/v1/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/superset/commands/importers/v1/utils.py b/superset/commands/importers/v1/utils.py new file mode 100644 index 000000000..c9475d83e --- /dev/null +++ b/superset/commands/importers/v1/utils.py @@ -0,0 +1,67 @@ +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +import logging +from typing import Any, Dict + +import yaml +from marshmallow import fields, Schema, validate +from marshmallow.exceptions import ValidationError + +from superset.commands.importers.exceptions import IncorrectVersionError + +METADATA_FILE_NAME = "metadata.yaml" +IMPORT_VERSION = "1.0.0" + +logger = logging.getLogger(__name__) + + +class MetadataSchema(Schema): + version = fields.String(required=True, validate=validate.Equal(IMPORT_VERSION)) + type = fields.String(required=True) + timestamp = fields.DateTime() + + +def load_yaml(file_name: str, content: str) -> Dict[str, Any]: + """Try to load a YAML file""" + try: + return yaml.safe_load(content) + except yaml.parser.ParserError: + logger.exception("Invalid YAML in %s", METADATA_FILE_NAME) + raise ValidationError({file_name: "Not a valid YAML file"}) + + +def load_metadata(contents: Dict[str, str]) -> Dict[str, str]: + """Apply validation and load a metadata file""" + if METADATA_FILE_NAME not in contents: + # if the contents have no METADATA_FILE_NAME this is probably + # an original export without versioning that should not be + # handled by this command + raise IncorrectVersionError(f"Missing {METADATA_FILE_NAME}") + + metadata = load_yaml(METADATA_FILE_NAME, contents[METADATA_FILE_NAME]) + try: + MetadataSchema().load(metadata) + except ValidationError as exc: + # if the version doesn't match raise an exception so that the + # dispatcher can try a different command version + if "version" in exc.messages: + raise IncorrectVersionError(exc.messages["version"][0]) + + # otherwise we raise the validation error + exc.messages = {METADATA_FILE_NAME: exc.messages} + raise exc + + return metadata diff --git a/superset/databases/commands/importers/__init__.py b/superset/databases/commands/importers/__init__.py new file mode 100644 index 000000000..13a83393a --- /dev/null +++ 
b/superset/databases/commands/importers/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/superset/databases/commands/importers/v1/__init__.py b/superset/databases/commands/importers/v1/__init__.py new file mode 100644 index 000000000..cfe82cc48 --- /dev/null +++ b/superset/databases/commands/importers/v1/__init__.py @@ -0,0 +1,116 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from typing import Any, Dict, List + +from marshmallow import Schema, validate +from marshmallow.exceptions import ValidationError +from sqlalchemy.orm import Session + +from superset import db +from superset.commands.base import BaseCommand +from superset.commands.exceptions import CommandInvalidError +from superset.commands.importers.v1.utils import ( + load_metadata, + load_yaml, + METADATA_FILE_NAME, +) +from superset.databases.commands.importers.v1.utils import import_database +from superset.databases.schemas import ImportV1DatabaseSchema +from superset.datasets.commands.importers.v1.utils import import_dataset +from superset.datasets.schemas import ImportV1DatasetSchema +from superset.models.core import Database + +schemas: Dict[str, Schema] = { + "databases/": ImportV1DatabaseSchema(), + "datasets/": ImportV1DatasetSchema(), +} + + +class ImportDatabasesCommand(BaseCommand): + + """Import databases""" + + # pylint: disable=unused-argument + def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any): + self.contents = contents + self._configs: Dict[str, Any] = {} + + def _import_bundle(self, session: Session) -> None: + # first import databases + database_ids: Dict[str, int] = {} + for file_name, config in self._configs.items(): + if file_name.startswith("databases/"): + database = import_database(session, config, overwrite=True) + database_ids[str(database.uuid)] = database.id + + # import related datasets + for file_name, config in self._configs.items(): + if ( + file_name.startswith("datasets/") + and config["database_uuid"] in database_ids + ): + config["database_id"] = database_ids[config["database_uuid"]] + # overwrite=False prevents deleting any non-imported columns/metrics + import_dataset(session, config, overwrite=False) + + def run(self) -> None: + self.validate() + + # rollback to prevent partial imports + try: + self._import_bundle(db.session) + db.session.commit() + except Exception as exc: + db.session.rollback() + raise exc + + 
def validate(self) -> None: + exceptions: List[ValidationError] = [] + + # verify that the metadata file is present and valid + try: + metadata = load_metadata(self.contents) + except ValidationError as exc: + exceptions.append(exc) + metadata = None + + for file_name, content in self.contents.items(): + prefix = file_name.split("/")[0] + schema = schemas.get(f"{prefix}/") + if schema: + try: + config = load_yaml(file_name, content) + schema.load(config) + self._configs[file_name] = config + except ValidationError as exc: + exc.messages = {file_name: exc.messages} + exceptions.append(exc) + + # validate that the type declared in METADATA_FILE_NAME is correct + if metadata: + type_validator = validate.Equal(Database.__name__) + try: + type_validator(metadata["type"]) + except ValidationError as exc: + exc.messages = {METADATA_FILE_NAME: {"type": exc.messages}} + exceptions.append(exc) + + if exceptions: + exception = CommandInvalidError("Error importing database") + exception.add_list(exceptions) + raise exception diff --git a/superset/databases/commands/importers/v1/utils.py b/superset/databases/commands/importers/v1/utils.py new file mode 100644 index 000000000..6e016d0e0 --- /dev/null +++ b/superset/databases/commands/importers/v1/utils.py @@ -0,0 +1,42 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +import json +from typing import Any, Dict + +from sqlalchemy.orm import Session + +from superset.models.core import Database + + +def import_database( + session: Session, config: Dict[str, Any], overwrite: bool = False +) -> Database: + existing = session.query(Database).filter_by(uuid=config["uuid"]).first() + if existing: + if not overwrite: + return existing + config["id"] = existing.id + + # TODO (betodealmeida): move this logic to import_from_dict + config["extra"] = json.dumps(config["extra"]) + + database = Database.import_from_dict(session, config, recursive=False) + if database.id is None: + session.flush() + + return database diff --git a/superset/databases/schemas.py b/superset/databases/schemas.py index e0f7242a9..d5e9ba3de 100644 --- a/superset/databases/schemas.py +++ b/superset/databases/schemas.py @@ -408,3 +408,24 @@ class DatabaseRelatedDashboards(Schema): class DatabaseRelatedObjectsResponse(Schema): charts = fields.Nested(DatabaseRelatedCharts) dashboards = fields.Nested(DatabaseRelatedDashboards) + + +class ImportV1DatabaseExtraSchema(Schema): + metadata_params = fields.Dict(keys=fields.Str(), values=fields.Raw()) + engine_params = fields.Dict(keys=fields.Str(), values=fields.Raw()) + metadata_cache_timeout = fields.Dict(keys=fields.Str(), values=fields.Integer()) + schemas_allowed_for_csv_upload = fields.List(fields.String) + + +class ImportV1DatabaseSchema(Schema): + database_name = fields.String(required=True) + sqlalchemy_uri = fields.String(required=True) + cache_timeout = fields.Integer(allow_none=True) + expose_in_sqllab = fields.Boolean() + allow_run_async = fields.Boolean() + allow_ctas = fields.Boolean() + allow_cvas = fields.Boolean() + allow_csv_upload = fields.Boolean() + extra = fields.Nested(ImportV1DatabaseExtraSchema) + uuid = fields.UUID(required=True) + version = fields.String(required=True) diff --git 
a/superset/datasets/commands/importers/v0.py b/superset/datasets/commands/importers/v0.py index 5b3ed25d7..b8c0ab65a 100644 --- a/superset/datasets/commands/importers/v0.py +++ b/superset/datasets/commands/importers/v0.py @@ -282,9 +282,19 @@ class ImportDatasetsCommand(BaseCommand): in Superset. """ - def __init__(self, contents: Dict[str, str], sync: Optional[List[str]] = None): + def __init__( + self, + contents: Dict[str, str], + sync_columns: bool = False, + sync_metrics: bool = False, + ): self.contents = contents - self.sync = sync + + self.sync = [] + if sync_columns: + self.sync.append("columns") + if sync_metrics: + self.sync.append("metrics") def run(self) -> None: self.validate() diff --git a/superset/datasets/commands/importers/v1/__init__.py b/superset/datasets/commands/importers/v1/__init__.py new file mode 100644 index 000000000..b6f2649c6 --- /dev/null +++ b/superset/datasets/commands/importers/v1/__init__.py @@ -0,0 +1,121 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from typing import Any, Dict, List, Set + +from marshmallow import Schema, validate +from marshmallow.exceptions import ValidationError +from sqlalchemy.orm import Session + +from superset import db +from superset.commands.base import BaseCommand +from superset.commands.exceptions import CommandInvalidError +from superset.commands.importers.v1.utils import ( + load_metadata, + load_yaml, + METADATA_FILE_NAME, +) +from superset.connectors.sqla.models import SqlaTable +from superset.databases.commands.importers.v1.utils import import_database +from superset.databases.schemas import ImportV1DatabaseSchema +from superset.datasets.commands.importers.v1.utils import import_dataset +from superset.datasets.schemas import ImportV1DatasetSchema + +schemas: Dict[str, Schema] = { + "databases/": ImportV1DatabaseSchema(), + "datasets/": ImportV1DatasetSchema(), +} + + +class ImportDatasetsCommand(BaseCommand): + + """Import datasets""" + + # pylint: disable=unused-argument + def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any): + self.contents = contents + self._configs: Dict[str, Any] = {} + + def _import_bundle(self, session: Session) -> None: + # discover databases associated with datasets + database_uuids: Set[str] = set() + for file_name, config in self._configs.items(): + if file_name.startswith("datasets/"): + database_uuids.add(config["database_uuid"]) + + # import related databases + database_ids: Dict[str, int] = {} + for file_name, config in self._configs.items(): + if file_name.startswith("databases/") and config["uuid"] in database_uuids: + database = import_database(session, config, overwrite=False) + database_ids[str(database.uuid)] = database.id + + # import datasets with the correct parent ref + for file_name, config in self._configs.items(): + if ( + file_name.startswith("datasets/") + and config["database_uuid"] in database_ids + ): + config["database_id"] = database_ids[config["database_uuid"]] + import_dataset(session, config, 
overwrite=True) + + def run(self) -> None: + self.validate() + + # rollback to prevent partial imports + try: + self._import_bundle(db.session) + db.session.commit() + except Exception as exc: + db.session.rollback() + raise exc + + def validate(self) -> None: + exceptions: List[ValidationError] = [] + + # verify that the metadata file is present and valid + try: + metadata = load_metadata(self.contents) + except ValidationError as exc: + exceptions.append(exc) + metadata = None + + for file_name, content in self.contents.items(): + prefix = file_name.split("/")[0] + schema = schemas.get(f"{prefix}/") + if schema: + try: + config = load_yaml(file_name, content) + schema.load(config) + self._configs[file_name] = config + except ValidationError as exc: + exc.messages = {file_name: exc.messages} + exceptions.append(exc) + + # validate that the type declared in METADATA_FILE_NAME is correct + if metadata: + type_validator = validate.Equal(SqlaTable.__name__) + try: + type_validator(metadata["type"]) + except ValidationError as exc: + exc.messages = {METADATA_FILE_NAME: {"type": exc.messages}} + exceptions.append(exc) + + if exceptions: + exception = CommandInvalidError("Error importing dataset") + exception.add_list(exceptions) + raise exception diff --git a/superset/datasets/commands/importers/v1/utils.py b/superset/datasets/commands/importers/v1/utils.py new file mode 100644 index 000000000..99326f3c3 --- /dev/null +++ b/superset/datasets/commands/importers/v1/utils.py @@ -0,0 +1,42 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Any, Dict + +from sqlalchemy.orm import Session + +from superset.connectors.sqla.models import SqlaTable + + +def import_dataset( + session: Session, config: Dict[str, Any], overwrite: bool = False +) -> SqlaTable: + existing = session.query(SqlaTable).filter_by(uuid=config["uuid"]).first() + if existing: + if not overwrite: + return existing + config["id"] = existing.id + + # should we delete columns and metrics not present in the current import? + sync = ["columns", "metrics"] if overwrite else [] + + # import recursively to include columns and metrics + dataset = SqlaTable.import_from_dict(session, config, recursive=True, sync=sync) + if dataset.id is None: + session.flush() + + return dataset diff --git a/superset/datasets/schemas.py b/superset/datasets/schemas.py index 70d965699..f32e8d57f 100644 --- a/superset/datasets/schemas.py +++ b/superset/datasets/schemas.py @@ -122,3 +122,48 @@ class DatasetRelatedDashboards(Schema): class DatasetRelatedObjectsResponse(Schema): charts = fields.Nested(DatasetRelatedCharts) dashboards = fields.Nested(DatasetRelatedDashboards) + + +class ImportV1ColumnSchema(Schema): + column_name = fields.String(required=True) + verbose_name = fields.String() + is_dttm = fields.Boolean() + is_active = fields.Boolean(allow_none=True) + type = fields.String(required=True) + groupby = fields.Boolean() + filterable = fields.Boolean() + expression = fields.String() + description = fields.String(allow_none=True) + python_date_format = fields.String(allow_none=True) + + +class ImportV1MetricSchema(Schema): + 
metric_name = fields.String(required=True) + verbose_name = fields.String() + metric_type = fields.String(allow_none=True) + expression = fields.String(required=True) + description = fields.String(allow_none=True) + d3format = fields.String(allow_none=True) + extra = fields.String(allow_none=True) + warning_text = fields.String(allow_none=True) + + +class ImportV1DatasetSchema(Schema): + table_name = fields.String(required=True) + main_dttm_col = fields.String(allow_none=True) + description = fields.String() + default_endpoint = fields.String() + offset = fields.Integer() + cache_timeout = fields.Integer() + schema = fields.String() + sql = fields.String() + params = fields.String(allow_none=True) + template_params = fields.String(allow_none=True) + filter_select_enabled = fields.Boolean() + fetch_values_predicate = fields.String(allow_none=True) + extra = fields.String(allow_none=True) + uuid = fields.UUID(required=True) + columns = fields.List(fields.Nested(ImportV1ColumnSchema)) + metrics = fields.List(fields.Nested(ImportV1MetricSchema)) + version = fields.String(required=True) + database_uuid = fields.UUID(required=True) diff --git a/superset/models/helpers.py b/superset/models/helpers.py index 14cad2eb2..f0e5c0b5f 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -163,7 +163,7 @@ class ImportExportMixin: if sync is None: sync = [] parent_refs = cls.parent_foreign_key_mappings() - export_fields = set(cls.export_fields) | set(parent_refs.keys()) + export_fields = set(cls.export_fields) | set(parent_refs.keys()) | {"uuid"} new_children = {c: dict_rep[c] for c in cls.export_children if c in dict_rep} unique_constrains = cls._unique_constrains() diff --git a/tests/databases/commands_tests.py b/tests/databases/commands_tests.py index bd4d3438d..a88283f23 100644 --- a/tests/databases/commands_tests.py +++ b/tests/databases/commands_tests.py @@ -14,16 +14,29 @@ # KIND, either express or implied. 
See the License for the # specific language governing permissions and limitations # under the License. +# pylint: disable=no-self-use, invalid-name from unittest.mock import patch +import pytest import yaml -from superset import security_manager +from superset import db, security_manager +from superset.commands.exceptions import CommandInvalidError +from superset.commands.importers.exceptions import IncorrectVersionError +from superset.connectors.sqla.models import SqlaTable from superset.databases.commands.exceptions import DatabaseNotFoundError from superset.databases.commands.export import ExportDatabasesCommand +from superset.databases.commands.importers.v1 import ImportDatabasesCommand +from superset.models.core import Database from superset.utils.core import backend, get_example_database from tests.base_tests import SupersetTestCase +from tests.fixtures.importexport import ( + database_config, + database_metadata_config, + dataset_config, + dataset_metadata_config, +) class TestExportDatabasesCommand(SupersetTestCase): @@ -265,3 +278,197 @@ class TestExportDatabasesCommand(SupersetTestCase): "uuid", "version", ] + + def test_import_v1_database(self): + """Test that a database can be imported""" + contents = { + "metadata.yaml": yaml.safe_dump(database_metadata_config), + "databases/imported_database.yaml": yaml.safe_dump(database_config), + } + command = ImportDatabasesCommand(contents) + command.run() + + database = ( + db.session.query(Database).filter_by(uuid=database_config["uuid"]).one() + ) + assert database.allow_csv_upload + assert database.allow_ctas + assert database.allow_cvas + assert not database.allow_run_async + assert database.cache_timeout is None + assert database.database_name == "imported_database" + assert database.expose_in_sqllab + assert database.extra == "{}" + assert database.sqlalchemy_uri == "sqlite:///test.db" + + db.session.delete(database) + db.session.commit() + + def test_import_v1_database_multiple(self): + """Test that a 
database can be imported multiple times""" + num_databases = db.session.query(Database).count() + + contents = { + "databases/imported_database.yaml": yaml.safe_dump(database_config), + "metadata.yaml": yaml.safe_dump(database_metadata_config), + } + command = ImportDatabasesCommand(contents) + + # import twice + command.run() + command.run() + + database = ( + db.session.query(Database).filter_by(uuid=database_config["uuid"]).one() + ) + assert database.allow_csv_upload + + # update allow_csv_upload to False + new_config = database_config.copy() + new_config["allow_csv_upload"] = False + contents = { + "databases/imported_database.yaml": yaml.safe_dump(new_config), + "metadata.yaml": yaml.safe_dump(database_metadata_config), + } + command = ImportDatabasesCommand(contents) + command.run() + + database = ( + db.session.query(Database).filter_by(uuid=database_config["uuid"]).one() + ) + assert not database.allow_csv_upload + + # test that only one database was created + new_num_databases = db.session.query(Database).count() + assert new_num_databases == num_databases + 1 + + db.session.delete(database) + db.session.commit() + + def test_import_v1_database_with_dataset(self): + """Test that a database can be imported with datasets""" + contents = { + "databases/imported_database.yaml": yaml.safe_dump(database_config), + "datasets/imported_dataset.yaml": yaml.safe_dump(dataset_config), + "metadata.yaml": yaml.safe_dump(database_metadata_config), + } + command = ImportDatabasesCommand(contents) + command.run() + + database = ( + db.session.query(Database).filter_by(uuid=database_config["uuid"]).one() + ) + assert len(database.tables) == 1 + assert str(database.tables[0].uuid) == "10808100-158b-42c4-842e-f32b99d88dfb" + + db.session.delete(database.tables[0]) + db.session.delete(database) + db.session.commit() + + def test_import_v1_database_with_dataset_multiple(self): + """Test that a database can be imported multiple times w/o changing datasets""" + contents = { + 
"databases/imported_database.yaml": yaml.safe_dump(database_config), + "datasets/imported_dataset.yaml": yaml.safe_dump(dataset_config), + "metadata.yaml": yaml.safe_dump(database_metadata_config), + } + command = ImportDatabasesCommand(contents) + command.run() + + dataset = ( + db.session.query(SqlaTable).filter_by(uuid=dataset_config["uuid"]).one() + ) + assert dataset.offset == 66 + + new_config = dataset_config.copy() + new_config["offset"] = 67 + contents = { + "databases/imported_database.yaml": yaml.safe_dump(database_config), + "datasets/imported_dataset.yaml": yaml.safe_dump(new_config), + "metadata.yaml": yaml.safe_dump(database_metadata_config), + } + command = ImportDatabasesCommand(contents) + command.run() + + # the underlying dataset should not be modified by the second import, since + # we're importing a database, not a dataset + dataset = ( + db.session.query(SqlaTable).filter_by(uuid=dataset_config["uuid"]).one() + ) + assert dataset.offset == 66 + + db.session.delete(dataset) + db.session.delete(dataset.database) + db.session.commit() + + def test_import_v1_database_validation(self): + """Test different validations applied when importing a database""" + # metadata.yaml must be present + contents = { + "databases/imported_database.yaml": yaml.safe_dump(database_config), + } + command = ImportDatabasesCommand(contents) + with pytest.raises(IncorrectVersionError) as excinfo: + command.run() + assert str(excinfo.value) == "Missing metadata.yaml" + + # version should be 1.0.0 + contents["metadata.yaml"] = yaml.safe_dump( + { + "version": "2.0.0", + "type": "Database", + "timestamp": "2020-11-04T21:27:44.423819+00:00", + } + ) + command = ImportDatabasesCommand(contents) + with pytest.raises(IncorrectVersionError) as excinfo: + command.run() + assert str(excinfo.value) == "Must be equal to 1.0.0." 
+ + # type should be Database + contents["metadata.yaml"] = yaml.safe_dump(dataset_metadata_config) + command = ImportDatabasesCommand(contents) + with pytest.raises(CommandInvalidError) as excinfo: + command.run() + assert str(excinfo.value) == "Error importing database" + assert excinfo.value.normalized_messages() == { + "metadata.yaml": {"type": ["Must be equal to Database."],} + } + + # must also validate datasets + broken_config = dataset_config.copy() + del broken_config["table_name"] + contents["metadata.yaml"] = yaml.safe_dump(database_metadata_config) + contents["datasets/imported_dataset.yaml"] = yaml.safe_dump(broken_config) + command = ImportDatabasesCommand(contents) + with pytest.raises(CommandInvalidError) as excinfo: + command.run() + assert str(excinfo.value) == "Error importing database" + assert excinfo.value.normalized_messages() == { + "datasets/imported_dataset.yaml": { + "table_name": ["Missing data for required field."], + } + } + + @patch("superset.databases.commands.importers.v1.import_dataset") + def test_import_v1_rollback(self, mock_import_dataset): + """Test than on an exception everything is rolled back""" + num_databases = db.session.query(Database).count() + + # raise an exception when importing the dataset, after the database has + # already been imported + mock_import_dataset.side_effect = Exception("A wild exception appears!") + + contents = { + "databases/imported_database.yaml": yaml.safe_dump(database_config), + "datasets/imported_dataset.yaml": yaml.safe_dump(dataset_config), + "metadata.yaml": yaml.safe_dump(database_metadata_config), + } + command = ImportDatabasesCommand(contents) + with pytest.raises(Exception) as excinfo: + command.run() + assert str(excinfo.value) == "A wild exception appears!" 
+ + # verify that the database was not added + new_num_databases = db.session.query(Database).count() + assert new_num_databases == num_databases diff --git a/tests/datasets/commands_tests.py b/tests/datasets/commands_tests.py index 17afe1266..94b5e7928 100644 --- a/tests/datasets/commands_tests.py +++ b/tests/datasets/commands_tests.py @@ -14,18 +14,29 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# pylint: disable=no-self-use, invalid-name from operator import itemgetter from unittest.mock import patch +import pytest import yaml -from superset import security_manager +from superset import db, security_manager +from superset.commands.exceptions import CommandInvalidError +from superset.commands.importers.exceptions import IncorrectVersionError from superset.connectors.sqla.models import SqlaTable from superset.datasets.commands.exceptions import DatasetNotFoundError from superset.datasets.commands.export import ExportDatasetsCommand -from superset.utils.core import backend, get_example_database +from superset.datasets.commands.importers.v1 import ImportDatasetsCommand +from superset.utils.core import get_example_database from tests.base_tests import SupersetTestCase +from tests.fixtures.importexport import ( + database_config, + database_metadata_config, + dataset_config, + dataset_metadata_config, +) class TestExportDatasetsCommand(SupersetTestCase): @@ -186,3 +197,149 @@ class TestExportDatasetsCommand(SupersetTestCase): "version", "database_uuid", ] + + def test_import_v1_dataset(self): + """Test that we can import a dataset""" + contents = { + "metadata.yaml": yaml.safe_dump(dataset_metadata_config), + "databases/imported_database.yaml": yaml.safe_dump(database_config), + "datasets/imported_dataset.yaml": yaml.safe_dump(dataset_config), + } + command = ImportDatasetsCommand(contents) + command.run() + + dataset = ( + 
db.session.query(SqlaTable).filter_by(uuid=dataset_config["uuid"]).one() + ) + assert dataset.table_name == "imported_dataset" + assert dataset.main_dttm_col is None + assert dataset.description == "This is a dataset that was exported" + assert dataset.default_endpoint == "" + assert dataset.offset == 66 + assert dataset.cache_timeout == 55 + assert dataset.schema == "" + assert dataset.sql == "" + assert dataset.params is None + assert dataset.template_params is None + assert dataset.filter_select_enabled + assert dataset.fetch_values_predicate is None + assert dataset.extra is None + + # database is also imported + assert str(dataset.database.uuid) == "b8a1ccd3-779d-4ab7-8ad8-9ab119d7fe89" + + assert len(dataset.metrics) == 1 + metric = dataset.metrics[0] + assert metric.metric_name == "count" + assert metric.verbose_name == "" + assert metric.metric_type is None + assert metric.expression == "count(1)" + assert metric.description is None + assert metric.d3format is None + assert metric.extra is None + assert metric.warning_text is None + + assert len(dataset.columns) == 1 + column = dataset.columns[0] + assert column.column_name == "cnt" + assert column.verbose_name == "Count of something" + assert not column.is_dttm + assert column.is_active # imported columns are set to active + assert column.type == "NUMBER" + assert not column.groupby + assert column.filterable + assert column.expression == "" + assert column.description is None + assert column.python_date_format is None + + db.session.delete(dataset) + db.session.delete(dataset.database) + db.session.commit() + + def test_import_v1_dataset_multiple(self): + """Test that a dataset can be imported multiple times""" + contents = { + "metadata.yaml": yaml.safe_dump(dataset_metadata_config), + "databases/imported_database.yaml": yaml.safe_dump(database_config), + "datasets/imported_dataset.yaml": yaml.safe_dump(dataset_config), + } + command = ImportDatasetsCommand(contents) + command.run() + command.run() + 
dataset = ( + db.session.query(SqlaTable).filter_by(uuid=dataset_config["uuid"]).one() + ) + assert dataset.table_name == "imported_dataset" + + # test that columns and metrics sync, i.e., old ones not in the import + # are removed + new_config = dataset_config.copy() + new_config["metrics"][0]["metric_name"] = "count2" + new_config["columns"][0]["column_name"] = "cnt2" + contents = { + "metadata.yaml": yaml.safe_dump(dataset_metadata_config), + "databases/imported_database.yaml": yaml.safe_dump(database_config), + "datasets/imported_dataset.yaml": yaml.safe_dump(new_config), + } + command = ImportDatasetsCommand(contents) + command.run() + dataset = ( + db.session.query(SqlaTable).filter_by(uuid=dataset_config["uuid"]).one() + ) + assert len(dataset.metrics) == 1 + assert dataset.metrics[0].metric_name == "count2" + assert len(dataset.columns) == 1 + assert dataset.columns[0].column_name == "cnt2" + + db.session.delete(dataset) + db.session.delete(dataset.database) + db.session.commit() + + def test_import_v1_dataset_validation(self): + """Test different validations applied when importing a dataset""" + # metadata.yaml must be present + contents = { + "datasets/imported_dataset.yaml": yaml.safe_dump(dataset_config), + } + command = ImportDatasetsCommand(contents) + with pytest.raises(IncorrectVersionError) as excinfo: + command.run() + assert str(excinfo.value) == "Missing metadata.yaml" + + # version should be 1.0.0 + contents["metadata.yaml"] = yaml.safe_dump( + { + "version": "2.0.0", + "type": "SqlaTable", + "timestamp": "2020-11-04T21:27:44.423819+00:00", + } + ) + command = ImportDatasetsCommand(contents) + with pytest.raises(IncorrectVersionError) as excinfo: + command.run() + assert str(excinfo.value) == "Must be equal to 1.0.0." 
+ + # type should be SqlaTable + contents["metadata.yaml"] = yaml.safe_dump(database_metadata_config) + command = ImportDatasetsCommand(contents) + with pytest.raises(CommandInvalidError) as excinfo: + command.run() + assert str(excinfo.value) == "Error importing dataset" + assert excinfo.value.normalized_messages() == { + "metadata.yaml": {"type": ["Must be equal to SqlaTable."],} + } + + # must also validate databases + broken_config = database_config.copy() + del broken_config["database_name"] + contents["metadata.yaml"] = yaml.safe_dump(dataset_metadata_config) + contents["databases/imported_database.yaml"] = yaml.safe_dump(broken_config) + command = ImportDatasetsCommand(contents) + with pytest.raises(CommandInvalidError) as excinfo: + command.run() + assert str(excinfo.value) == "Error importing dataset" + assert excinfo.value.normalized_messages() == { + "databases/imported_database.yaml": { + "database_name": ["Missing data for required field."], + } + } diff --git a/tests/fixtures/importexport.py b/tests/fixtures/importexport.py new file mode 100644 index 000000000..8b6400414 --- /dev/null +++ b/tests/fixtures/importexport.py @@ -0,0 +1,90 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from typing import Any, Dict + +# example YAML files +database_metadata_config: Dict[str, Any] = { + "version": "1.0.0", + "type": "Database", + "timestamp": "2020-11-04T21:27:44.423819+00:00", +} + +dataset_metadata_config: Dict[str, Any] = { + "version": "1.0.0", + "type": "SqlaTable", + "timestamp": "2020-11-04T21:27:44.423819+00:00", +} + +database_config: Dict[str, Any] = { + "allow_csv_upload": True, + "allow_ctas": True, + "allow_cvas": True, + "allow_run_async": False, + "cache_timeout": None, + "database_name": "imported_database", + "expose_in_sqllab": True, + "extra": {}, + "sqlalchemy_uri": "sqlite:///test.db", + "uuid": "b8a1ccd3-779d-4ab7-8ad8-9ab119d7fe89", + "version": "1.0.0", +} + +dataset_config: Dict[str, Any] = { + "table_name": "imported_dataset", + "main_dttm_col": None, + "description": "This is a dataset that was exported", + "default_endpoint": "", + "offset": 66, + "cache_timeout": 55, + "schema": "", + "sql": "", + "params": None, + "template_params": None, + "filter_select_enabled": True, + "fetch_values_predicate": None, + "extra": None, + "metrics": [ + { + "metric_name": "count", + "verbose_name": "", + "metric_type": None, + "expression": "count(1)", + "description": None, + "d3format": None, + "extra": None, + "warning_text": None, + }, + ], + "columns": [ + { + "column_name": "cnt", + "verbose_name": "Count of something", + "is_dttm": False, + "is_active": None, + "type": "NUMBER", + "groupby": False, + "filterable": True, + "expression": "", + "description": None, + "python_date_format": None, + }, + ], + "version": "1.0.0", + "uuid": "10808100-158b-42c4-842e-f32b99d88dfb", + "database_uuid": "b8a1ccd3-779d-4ab7-8ad8-9ab119d7fe89", +}