diff --git a/superset/charts/commands/importers/__init__.py b/superset/charts/commands/importers/__init__.py new file mode 100644 index 000000000..13a83393a --- /dev/null +++ b/superset/charts/commands/importers/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/superset/charts/commands/importers/v1/__init__.py b/superset/charts/commands/importers/v1/__init__.py new file mode 100644 index 000000000..086a37070 --- /dev/null +++ b/superset/charts/commands/importers/v1/__init__.py @@ -0,0 +1,146 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Any, Dict, List, Optional, Set + +from marshmallow import Schema, validate +from marshmallow.exceptions import ValidationError +from sqlalchemy.orm import Session + +from superset import db +from superset.charts.commands.importers.v1.utils import import_chart +from superset.charts.schemas import ImportV1ChartSchema +from superset.commands.base import BaseCommand +from superset.commands.exceptions import CommandInvalidError +from superset.commands.importers.v1.utils import ( + load_metadata, + load_yaml, + METADATA_FILE_NAME, +) +from superset.databases.commands.importers.v1.utils import import_database +from superset.databases.schemas import ImportV1DatabaseSchema +from superset.datasets.commands.importers.v1.utils import import_dataset +from superset.datasets.schemas import ImportV1DatasetSchema +from superset.models.slice import Slice + +schemas: Dict[str, Schema] = { + "charts/": ImportV1ChartSchema(), + "datasets/": ImportV1DatasetSchema(), + "databases/": ImportV1DatabaseSchema(), +} + + +class ImportChartsCommand(BaseCommand): + + """Import charts""" + + # pylint: disable=unused-argument + def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any): + self.contents = contents + self._configs: Dict[str, Any] = {} + + def _import_bundle(self, session: Session) -> None: + # discover datasets associated with charts + dataset_uuids: Set[str] = set() + for file_name, config in self._configs.items(): + if file_name.startswith("charts/"): + dataset_uuids.add(config["dataset_uuid"]) + + # discover databases associated with datasets + database_uuids: Set[str] = set() + for file_name, config in self._configs.items(): + if file_name.startswith("datasets/") and config["uuid"] in dataset_uuids: + database_uuids.add(config["database_uuid"]) + + # import related databases + database_ids: Dict[str, int] = {} + for file_name, config in 
self._configs.items(): + if file_name.startswith("databases/") and config["uuid"] in database_uuids: + database = import_database(session, config, overwrite=False) + database_ids[str(database.uuid)] = database.id + + # import datasets with the correct parent ref + dataset_info: Dict[str, Dict[str, Any]] = {} + for file_name, config in self._configs.items(): + if ( + file_name.startswith("datasets/") + and config["database_uuid"] in database_ids + ): + config["database_id"] = database_ids[config["database_uuid"]] + dataset = import_dataset(session, config, overwrite=False) + dataset_info[str(dataset.uuid)] = { + "datasource_id": dataset.id, + "datasource_type": "view" if dataset.is_sqllab_view else "table", + "datasource_name": dataset.table_name, + } + + # import charts with the correct parent ref + for file_name, config in self._configs.items(): + if ( + file_name.startswith("charts/") + and config["dataset_uuid"] in dataset_info + ): + # update datasource id, type, and name + config.update(dataset_info[config["dataset_uuid"]]) + import_chart(session, config, overwrite=True) + + def run(self) -> None: + self.validate() + + # rollback to prevent partial imports + try: + self._import_bundle(db.session) + db.session.commit() + except Exception as exc: + db.session.rollback() + raise exc + + def validate(self) -> None: + exceptions: List[ValidationError] = [] + + # verify that the metadata file is present and valid + try: + metadata: Optional[Dict[str, str]] = load_metadata(self.contents) + except ValidationError as exc: + exceptions.append(exc) + metadata = None + + for file_name, content in self.contents.items(): + prefix = file_name.split("/")[0] + schema = schemas.get(f"{prefix}/") + if schema: + try: + config = load_yaml(file_name, content) + schema.load(config) + self._configs[file_name] = config + except ValidationError as exc: + exc.messages = {file_name: exc.messages} + exceptions.append(exc) + + # validate that the type declared in METADATA_FILE_NAME is 
correct + if metadata: + type_validator = validate.Equal(Slice.__name__) + try: + type_validator(metadata["type"]) + except ValidationError as exc: + exc.messages = {METADATA_FILE_NAME: {"type": exc.messages}} + exceptions.append(exc) + + if exceptions: + exception = CommandInvalidError("Error importing chart") + exception.add_list(exceptions) + raise exception diff --git a/superset/charts/commands/importers/v1/utils.py b/superset/charts/commands/importers/v1/utils.py new file mode 100644 index 000000000..b3d4237f2 --- /dev/null +++ b/superset/charts/commands/importers/v1/utils.py @@ -0,0 +1,42 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import json +from typing import Any, Dict + +from sqlalchemy.orm import Session + +from superset.models.slice import Slice + + +def import_chart( + session: Session, config: Dict[str, Any], overwrite: bool = False +) -> Slice: + existing = session.query(Slice).filter_by(uuid=config["uuid"]).first() + if existing: + if not overwrite: + return existing + config["id"] = existing.id + + # TODO (betodealmeida): move this logic to import_from_dict + config["params"] = json.dumps(config["params"]) + + chart = Slice.import_from_dict(session, config, recursive=False) + if chart.id is None: + session.flush() + + return chart diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py index 099703816..ca1497a80 100644 --- a/superset/charts/schemas.py +++ b/superset/charts/schemas.py @@ -1118,6 +1118,14 @@ class GetFavStarIdsSchema(Schema): ) +class ImportV1ChartSchema(Schema): + params = fields.Dict() + cache_timeout = fields.Integer(allow_none=True) + uuid = fields.UUID(required=True) + version = fields.String(required=True) + dataset_uuid = fields.UUID(required=True) + + CHART_SCHEMAS = ( ChartDataQueryContextSchema, ChartDataResponseSchema, diff --git a/superset/models/helpers.py b/superset/models/helpers.py index f0e5c0b5f..623bb07be 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -87,14 +87,6 @@ class ImportExportMixin: __mapper__: Mapper - @classmethod - def _parent_foreign_key_mappings(cls) -> Dict[str, str]: - """Get a mapping of foreign name to the local name of foreign keys""" - parent_rel = cls.__mapper__.relationships.get(cls.export_parent) - if parent_rel: - return {l.name: r.name for (l, r) in parent_rel.local_remote_pairs} - return {} - @classmethod def _unique_constrains(cls) -> List[Set[str]]: """Get all (single column and multi column) unique constraints""" @@ -171,7 +163,7 @@ class ImportExportMixin: # Remove fields that should not get imported for k in list(dict_rep): - if k not in export_fields: + if k not in 
export_fields and k not in parent_refs: del dict_rep[k] if not parent: diff --git a/superset/models/slice.py b/superset/models/slice.py index 7254652f5..2fd55a7ac 100644 --- a/superset/models/slice.py +++ b/superset/models/slice.py @@ -93,6 +93,7 @@ class Slice( "params", "cache_timeout", ] + export_parent = "table" def __repr__(self) -> str: return self.slice_name or str(self.id) diff --git a/tests/charts/commands_tests.py b/tests/charts/commands_tests.py index 9189d4cc4..8523b8317 100644 --- a/tests/charts/commands_tests.py +++ b/tests/charts/commands_tests.py @@ -14,16 +14,31 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# pylint: disable=no-self-use, invalid-name +import json from unittest.mock import patch +import pytest import yaml from superset import db, security_manager from superset.charts.commands.exceptions import ChartNotFoundError from superset.charts.commands.export import ExportChartsCommand +from superset.charts.commands.importers.v1 import ImportChartsCommand +from superset.commands.exceptions import CommandInvalidError +from superset.commands.importers.exceptions import IncorrectVersionError +from superset.connectors.sqla.models import SqlaTable +from superset.models.core import Database from superset.models.slice import Slice from tests.base_tests import SupersetTestCase +from tests.fixtures.importexport import ( + chart_config, + chart_metadata_config, + database_config, + database_metadata_config, + dataset_config, +) class TestExportChartsCommand(SupersetTestCase): @@ -49,7 +64,7 @@ class TestExportChartsCommand(SupersetTestCase): "viz_type": "sankey", "params": { "collapsed_fieldsets": "", - "groupby": ["source", "target",], + "groupby": ["source", "target"], "metric": "sum__value", "row_limit": "5000", "slice_name": "Energy Sankey", @@ -100,3 +115,143 @@ class TestExportChartsCommand(SupersetTestCase): "version", "dataset_uuid", ] + + def 
test_import_v1_chart(self): + """Test that we can import a chart""" + contents = { + "metadata.yaml": yaml.safe_dump(chart_metadata_config), + "databases/imported_database.yaml": yaml.safe_dump(database_config), + "datasets/imported_dataset.yaml": yaml.safe_dump(dataset_config), + "charts/imported_chart.yaml": yaml.safe_dump(chart_config), + } + command = ImportChartsCommand(contents) + command.run() + + chart = db.session.query(Slice).filter_by(uuid=chart_config["uuid"]).one() + assert json.loads(chart.params) == { + "color_picker": {"a": 1, "b": 135, "g": 122, "r": 0}, + "datasource": "12__table", + "js_columns": ["color"], + "js_data_mutator": "data => data.map(d => ({\\n ...d,\\n color: colors.hexToRGB(d.extraProps.color)\\n}));", + "js_onclick_href": "", + "js_tooltip": "", + "line_column": "path_json", + "line_type": "json", + "line_width": 150, + "mapbox_style": "mapbox://styles/mapbox/light-v9", + "reverse_long_lat": False, + "row_limit": 5000, + "slice_id": 43, + "time_grain_sqla": None, + "time_range": " : ", + "viewport": { + "altitude": 1.5, + "bearing": 0, + "height": 1094, + "latitude": 37.73671752604488, + "longitude": -122.18885402582598, + "maxLatitude": 85.05113, + "maxPitch": 60, + "maxZoom": 20, + "minLatitude": -85.05113, + "minPitch": 0, + "minZoom": 0, + "pitch": 0, + "width": 669, + "zoom": 9.51847667620428, + }, + "viz_type": "deck_path", + } + + dataset = ( + db.session.query(SqlaTable).filter_by(uuid=dataset_config["uuid"]).one() + ) + assert dataset.table_name == "imported_dataset" + assert chart.table == dataset + + database = ( + db.session.query(Database).filter_by(uuid=database_config["uuid"]).one() + ) + assert database.database_name == "imported_database" + assert chart.table.database == database + + db.session.delete(chart) + db.session.delete(dataset) + db.session.delete(database) + db.session.commit() + + def test_import_v1_chart_multiple(self): + """Test that a chart can be imported multiple times""" + contents = { + 
"metadata.yaml": yaml.safe_dump(chart_metadata_config), + "databases/imported_database.yaml": yaml.safe_dump(database_config), + "datasets/imported_dataset.yaml": yaml.safe_dump(dataset_config), + "charts/imported_chart.yaml": yaml.safe_dump(chart_config), + } + command = ImportChartsCommand(contents) + command.run() + command.run() + + dataset = ( + db.session.query(SqlaTable).filter_by(uuid=dataset_config["uuid"]).one() + ) + charts = db.session.query(Slice).filter_by(datasource_id=dataset.id).all() + assert len(charts) == 1 + + database = dataset.database + + db.session.delete(charts[0]) + db.session.delete(dataset) + db.session.delete(database) + db.session.commit() + + def test_import_v1_chart_validation(self): + """Test different validations applied when importing a chart""" + # metadata.yaml must be present + contents = { + "databases/imported_database.yaml": yaml.safe_dump(database_config), + "datasets/imported_dataset.yaml": yaml.safe_dump(dataset_config), + "charts/imported_chart.yaml": yaml.safe_dump(chart_config), + } + command = ImportChartsCommand(contents) + with pytest.raises(IncorrectVersionError) as excinfo: + command.run() + assert str(excinfo.value) == "Missing metadata.yaml" + + # version should be 1.0.0 + contents["metadata.yaml"] = yaml.safe_dump( + { + "version": "2.0.0", + "type": "SqlaTable", + "timestamp": "2020-11-04T21:27:44.423819+00:00", + } + ) + command = ImportChartsCommand(contents) + with pytest.raises(IncorrectVersionError) as excinfo: + command.run() + assert str(excinfo.value) == "Must be equal to 1.0.0." 
+ + # type should be Slice + contents["metadata.yaml"] = yaml.safe_dump(database_metadata_config) + command = ImportChartsCommand(contents) + with pytest.raises(CommandInvalidError) as excinfo: + command.run() + assert str(excinfo.value) == "Error importing chart" + assert excinfo.value.normalized_messages() == { + "metadata.yaml": {"type": ["Must be equal to Slice."]} + } + + # must also validate datasets and databases + broken_config = database_config.copy() + del broken_config["database_name"] + contents["metadata.yaml"] = yaml.safe_dump(chart_metadata_config) + contents["databases/imported_database.yaml"] = yaml.safe_dump(broken_config) + command = ImportChartsCommand(contents) + with pytest.raises(CommandInvalidError) as excinfo: + command.run() + assert str(excinfo.value) == "Error importing chart" + assert excinfo.value.normalized_messages() == { + "databases/imported_database.yaml": { + "database_name": ["Missing data for required field."], + } + } diff --git a/tests/datasets/commands_tests.py b/tests/datasets/commands_tests.py index 94b5e7928..a957ffc3b 100644 --- a/tests/datasets/commands_tests.py +++ b/tests/datasets/commands_tests.py @@ -26,9 +26,11 @@ from superset import db, security_manager from superset.commands.exceptions import CommandInvalidError from superset.commands.importers.exceptions import IncorrectVersionError from superset.connectors.sqla.models import SqlaTable +from superset.databases.commands.importers.v1 import ImportDatabasesCommand from superset.datasets.commands.exceptions import DatasetNotFoundError from superset.datasets.commands.export import ExportDatasetsCommand from superset.datasets.commands.importers.v1 import ImportDatasetsCommand +from superset.models.core import Database from superset.utils.core import get_example_database from tests.base_tests import SupersetTestCase from tests.fixtures.importexport import ( @@ -326,7 +328,7 @@ class TestExportDatasetsCommand(SupersetTestCase): command.run() assert str(excinfo.value) 
== "Error importing dataset" assert excinfo.value.normalized_messages() == { - "metadata.yaml": {"type": ["Must be equal to SqlaTable."],} + "metadata.yaml": {"type": ["Must be equal to SqlaTable."]} } # must also validate databases @@ -343,3 +345,32 @@ class TestExportDatasetsCommand(SupersetTestCase): "database_name": ["Missing data for required field."], } } + + def test_import_v1_dataset_existing_database(self): + """Test that a dataset can be imported when the database already exists""" + # first import database... + contents = { + "metadata.yaml": yaml.safe_dump(database_metadata_config), + "databases/imported_database.yaml": yaml.safe_dump(database_config), + } + command = ImportDatabasesCommand(contents) + command.run() + + database = ( + db.session.query(Database).filter_by(uuid=database_config["uuid"]).one() + ) + assert len(database.tables) == 0 + + # ...then dataset + contents = { + "metadata.yaml": yaml.safe_dump(dataset_metadata_config), + "datasets/imported_dataset.yaml": yaml.safe_dump(dataset_config), + "databases/imported_database.yaml": yaml.safe_dump(database_config), + } + command = ImportDatasetsCommand(contents) + command.run() + + database = ( + db.session.query(Database).filter_by(uuid=database_config["uuid"]).one() + ) + assert len(database.tables) == 1 diff --git a/tests/fixtures/importexport.py b/tests/fixtures/importexport.py index 8b6400414..8312a81c6 100644 --- a/tests/fixtures/importexport.py +++ b/tests/fixtures/importexport.py @@ -30,6 +30,12 @@ dataset_metadata_config: Dict[str, Any] = { "timestamp": "2020-11-04T21:27:44.423819+00:00", } +chart_metadata_config: Dict[str, Any] = { + "version": "1.0.0", + "type": "Slice", + "timestamp": "2020-11-04T21:27:44.423819+00:00", +} + database_config: Dict[str, Any] = { "allow_csv_upload": True, "allow_ctas": True, @@ -88,3 +94,44 @@ dataset_config: Dict[str, Any] = { "uuid": "10808100-158b-42c4-842e-f32b99d88dfb", "database_uuid": "b8a1ccd3-779d-4ab7-8ad8-9ab119d7fe89", } + +chart_config: 
Dict[str, Any] = { + "params": { + "color_picker": {"a": 1, "b": 135, "g": 122, "r": 0,}, + "datasource": "12__table", + "js_columns": ["color"], + "js_data_mutator": r"data => data.map(d => ({\n ...d,\n color: colors.hexToRGB(d.extraProps.color)\n}));", + "js_onclick_href": "", + "js_tooltip": "", + "line_column": "path_json", + "line_type": "json", + "line_width": 150, + "mapbox_style": "mapbox://styles/mapbox/light-v9", + "reverse_long_lat": False, + "row_limit": 5000, + "slice_id": 43, + "time_grain_sqla": None, + "time_range": " : ", + "viewport": { + "altitude": 1.5, + "bearing": 0, + "height": 1094, + "latitude": 37.73671752604488, + "longitude": -122.18885402582598, + "maxLatitude": 85.05113, + "maxPitch": 60, + "maxZoom": 20, + "minLatitude": -85.05113, + "minPitch": 0, + "minZoom": 0, + "pitch": 0, + "width": 669, + "zoom": 9.51847667620428, + }, + "viz_type": "deck_path", + }, + "cache_timeout": None, + "uuid": "0c23747a-6528-4629-97bf-e4b78d3b9df1", + "version": "1.0.0", + "dataset_uuid": "10808100-158b-42c4-842e-f32b99d88dfb", +}