diff --git a/requirements/testing.in b/requirements/testing.in index ec3489196..731d60aae 100644 --- a/requirements/testing.in +++ b/requirements/testing.in @@ -16,8 +16,9 @@ # -r development.in -r integration.in -flask-testing docker +flask-testing +freezegun ipdb # pinning ipython as pip-compile-multi was bringing higher version # of the ipython that was not found in CI diff --git a/requirements/testing.txt b/requirements/testing.txt index 0d496c808..08e9b6f4c 100644 --- a/requirements/testing.txt +++ b/requirements/testing.txt @@ -1,4 +1,4 @@ -# SHA1:0e68e30f4e1bc76d0ec05267a1e38451c3901384 +# SHA1:b16b83f856b2dbc53535f71414f5c3e8dfa838e0 # # This file is autogenerated by pip-compile-multi # To update, run: @@ -14,31 +14,30 @@ backcall==0.2.0 # via ipython coverage==5.3 # via pytest-cov docker==4.3.1 # via -r requirements/testing.in flask-testing==0.8.0 # via -r requirements/testing.in +freezegun==1.0.0 # via -r requirements/testing.in iniconfig==1.0.1 # via pytest -ipdb==0.13.3 # via -r requirements/testing.in +ipdb==0.13.4 # via -r requirements/testing.in ipython-genutils==0.2.0 # via traitlets ipython==7.16.1 # via -r requirements/testing.in, ipdb -isort==5.5.3 # via pylint +isort==5.6.4 # via pylint jedi==0.17.2 # via ipython lazy-object-proxy==1.4.3 # via astroid mccabe==0.6.1 # via pylint -more-itertools==8.5.0 # via pytest openapi-spec-validator==0.2.9 # via -r requirements/testing.in parameterized==0.7.4 # via -r requirements/testing.in parso==0.7.1 # via jedi pexpect==4.8.0 # via ipython pickleshare==0.7.5 # via ipython -prompt-toolkit==3.0.7 # via ipython +prompt-toolkit==3.0.8 # via ipython ptyprocess==0.6.0 # via pexpect pygments==2.7.1 # via ipython pyhive[hive,presto]==0.6.3 # via -r requirements/development.in, -r requirements/testing.in pylint==2.6.0 # via -r requirements/testing.in pytest-cov==2.10.1 # via -r requirements/testing.in -pytest==6.0.2 # via -r requirements/testing.in, pytest-cov +pytest==6.1.1 # via -r requirements/testing.in, pytest-cov redis==3.5.3 # via -r requirements/testing.in statsd==3.3.0 # via -r requirements/testing.in traitlets==5.0.4 # via ipython -typed-ast==1.4.1 # via astroid wcwidth==0.2.5 # via prompt-toolkit websocket-client==0.57.0 # via docker wrapt==1.12.1 # via astroid diff --git a/setup.cfg b/setup.cfg index 8e90ce560..80b13d611 100644 --- a/setup.cfg +++ b/setup.cfg @@ -30,7 +30,7 @@ combine_as_imports = true include_trailing_comma = true line_length = 88 known_first_party = superset -known_third_party =alembic,apispec,backoff,bleach,cachelib,celery,click,colorama,contextlib2,croniter,cryptography,dateutil,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,geohash,geopy,humanize,isodate,jinja2,markdown,markupsafe,marshmallow,msgpack,numpy,pandas,parameterized,parsedatetime,pathlib2,polyline,prison,pyarrow,pyhive,pytest,pytz,retry,selenium,setuptools,simplejson,slack,sqlalchemy,sqlalchemy_utils,sqlparse,werkzeug,wtforms,wtforms_json,yaml +known_third_party =alembic,apispec,backoff,bleach,cachelib,celery,click,colorama,contextlib2,croniter,cryptography,dateutil,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,freezegun,geohash,geopy,humanize,isodate,jinja2,markdown,markupsafe,marshmallow,msgpack,numpy,pandas,parameterized,parsedatetime,pathlib2,polyline,prison,pyarrow,pyhive,pytest,pytz,retry,selenium,setuptools,simplejson,slack,sqlalchemy,sqlalchemy_utils,sqlparse,werkzeug,wtforms,wtforms_json,yaml multi_line_output = 3 order_by_type = false diff --git a/superset/databases/api.py b/superset/databases/api.py index 59c575eff..ca1eb3bc7 100644 --- a/superset/databases/api.py +++ b/superset/databases/api.py @@ -15,9 +15,12 @@ # specific language governing permissions and limitations # under the License. import logging +from datetime import datetime +from io import BytesIO from typing import Any, Optional +from zipfile import ZipFile -from flask import g, request, Response +from flask import g, request, Response, send_file from flask_appbuilder.api import expose, protect, rison, safe from flask_appbuilder.models.sqla.interface import SQLAInterface from flask_babel import gettext as _ @@ -43,6 +46,7 @@ from superset.databases.commands.exceptions import ( DatabaseSecurityUnsafeError, DatabaseUpdateFailedError, ) +from superset.databases.commands.export import ExportDatabasesCommand from superset.databases.commands.test_connection import TestConnectionDatabaseCommand from superset.databases.commands.update import UpdateDatabaseCommand from superset.databases.dao import DatabaseDAO @@ -54,6 +58,7 @@ from superset.databases.schemas import ( DatabasePutSchema, DatabaseRelatedObjectsResponse, DatabaseTestConnectionSchema, + get_export_ids_schema, SchemasResponseSchema, SelectStarResponseSchema, TableMetadataResponseSchema, @@ -72,6 +77,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi): datamodel = SQLAInterface(Database) include_route_methods = RouteMethod.REST_MODEL_VIEW_CRUD_SET | { + RouteMethod.EXPORT, "table_metadata", "select_star", "schemas", @@ -653,3 +659,61 @@ class DatabaseRestApi(BaseSupersetModelRestApi): charts={"count": len(charts), "result": charts}, dashboards={"count": len(dashboards), "result": dashboards}, ) + + @expose("/export/", methods=["GET"]) + @protect() + @safe + @statsd_metrics + @rison(get_export_ids_schema) + def export(self, **kwargs: Any) -> Response: + """Export database(s) with associated datasets + --- + get: + description: Download database(s) and associated dataset(s) as a zip file + parameters: + - in: query + name: q + content: + application/json: + schema: + type: array + items: + type: integer + responses: + 200: + description: A zip file with database(s) and dataset(s) as YAML + content: + application/zip: + schema: + type: string + format: binary + 401: + $ref: '#/components/responses/401' + 404: + $ref: '#/components/responses/404' + 500: + $ref: '#/components/responses/500' + """ + requested_ids = kwargs["rison"] + timestamp = datetime.now().strftime("%Y%m%dT%H%M%S") + root = f"database_export_{timestamp}" + filename = f"{root}.zip" + + buf = BytesIO() + with ZipFile(buf, "w") as bundle: + try: + for file_name, file_content in ExportDatabasesCommand( + requested_ids + ).run(): + with bundle.open(f"{root}/{file_name}", "w") as fp: + fp.write(file_content.encode()) + except DatabaseNotFoundError: + return self.response_404() + buf.seek(0) + + return send_file( + buf, + mimetype="application/zip", + as_attachment=True, + attachment_filename=filename, + ) diff --git a/superset/databases/commands/exceptions.py b/superset/databases/commands/exceptions.py index 51d1660ca..1ef0a7944 100644 --- a/superset/databases/commands/exceptions.py +++ b/superset/databases/commands/exceptions.py @@ -28,7 +28,7 @@ from superset.security.analytics_db_safety import DBSecurityException class DatabaseInvalidError(CommandInvalidError): - message = _("Dashboard parameters are invalid.") + message = _("Database parameters are invalid.") class DatabaseExistsValidationError(ValidationError): diff --git a/superset/databases/commands/export.py b/superset/databases/commands/export.py new file mode 100644 index 000000000..ac2410c88 --- /dev/null +++ b/superset/databases/commands/export.py @@ -0,0 +1,89 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# isort:skip_file + +import json +from typing import Iterator, List, Tuple + +import yaml + +from superset.commands.base import BaseCommand +from superset.databases.commands.exceptions import DatabaseNotFoundError +from superset.databases.dao import DatabaseDAO +from superset.utils.dict_import_export import IMPORT_EXPORT_VERSION, sanitize +from superset.models.core import Database + + +class ExportDatabasesCommand(BaseCommand): + def __init__(self, database_ids: List[int]): + self.database_ids = database_ids + + # this will be set when calling validate() + self._models: List[Database] = [] + + @staticmethod + def export_database(database: Database) -> Iterator[Tuple[str, str]]: + name = sanitize(database.database_name) + file_name = f"databases/{name}.yaml" + + payload = database.export_to_dict( + recursive=False, + include_parent_ref=False, + include_defaults=True, + export_uuids=True, + ) + # TODO (betodealmeida): move this logic to export_to_dict once this + # becomes the default export endpoint + if "extra" in payload: + try: + payload["extra"] = json.loads(payload["extra"]) + except json.decoder.JSONDecodeError: + pass + + payload["version"] = IMPORT_EXPORT_VERSION + + file_content = yaml.safe_dump(payload, sort_keys=False) + yield file_name, file_content + + # TODO (betodealmeida): reuse logic from ExportDatasetCommand once + # it's implemented + for dataset in database.tables: + name = sanitize(dataset.table_name) + file_name = f"datasets/{name}.yaml" + + payload = dataset.export_to_dict( + recursive=True, + include_parent_ref=False, + include_defaults=True, + export_uuids=True, + ) + payload["version"] = IMPORT_EXPORT_VERSION + payload["database_uuid"] = str(database.uuid) + + file_content = yaml.safe_dump(payload, sort_keys=False) + yield file_name, file_content + + def run(self) -> Iterator[Tuple[str, str]]: + self.validate() + + for database in self._models: + yield from self.export_database(database) + + def validate(self) -> None: + self._models = DatabaseDAO.find_by_ids(self.database_ids) + if len(self._models) != len(self.database_ids): + raise DatabaseNotFoundError() diff --git a/superset/databases/schemas.py b/superset/databases/schemas.py index 7dd7bc41f..4bcca2153 100644 --- a/superset/databases/schemas.py +++ b/superset/databases/schemas.py @@ -109,6 +109,7 @@ extra_description = markdown( "whether or not the Explore button in SQL Lab results is shown.", True, ) +get_export_ids_schema = {"type": "array", "items": {"type": "integer"}} sqlalchemy_uri_description = markdown( "Refer to the " "[SqlAlchemy docs]" diff --git a/superset/models/helpers.py b/superset/models/helpers.py index 08301a7fe..2d8b432db 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -52,6 +52,23 @@ def json_to_dict(json_str: str) -> Dict[Any, Any]: return {} +def convert_uuids(obj: Any) -> Any: + """ + Convert UUID objects to str so we can use yaml.safe_dump + """ + if isinstance(obj, uuid.UUID): + return str(obj) + + if isinstance(obj, list): + return [convert_uuids(el) for el in obj] + + if isinstance(obj, dict): + return {k: convert_uuids(v) for k, v in obj.items()} + + return obj + + +# TODO (betodealmeida): rename to ImportExportMixin class ImportMixin: uuid = sa.Column( UUIDType(binary=True), primary_key=False, unique=True, default=uuid.uuid4 @@ -247,8 +264,15 @@ class ImportMixin: recursive: bool = True, include_parent_ref: bool = False, include_defaults: bool = False, + export_uuids: bool = False, ) -> Dict[Any, Any]: """Export obj to dictionary""" + export_fields = set(self.export_fields) + if export_uuids: + export_fields.add("uuid") + if "id" in export_fields: + export_fields.remove("id") + cls = self.__class__ parent_excludes = set() if recursive and not include_parent_ref: @@ -259,7 +283,7 @@ class ImportMixin: c.name: getattr(self, c.name) for c in cls.__table__.columns # type: ignore if ( - c.name in self.export_fields + c.name in export_fields and c.name not in parent_excludes and ( include_defaults @@ -270,6 +294,13 @@ class ImportMixin: ) ) } + + # sort according to export_fields using DSU (decorate, sort, undecorate) + order = {field: i for i, field in enumerate(self.export_fields)} + decorated_keys = [(order.get(k, len(order)), k) for k in dict_rep] + decorated_keys.sort() + dict_rep = {k: dict_rep[k] for _, k in decorated_keys} + if recursive: for cld in self.export_children: # sorting to make lists of children stable @@ -285,7 +316,7 @@ class ImportMixin: key=lambda k: sorted(str(k.items())), ) - return dict_rep + return convert_uuids(dict_rep) def override(self, obj: Any) -> None: """Overrides the plain fields of the dashboard.""" diff --git a/superset/utils/dict_import_export.py b/superset/utils/dict_import_export.py index 4d9e0496b..732acbf90 100644 --- a/superset/utils/dict_import_export.py +++ b/superset/utils/dict_import_export.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. import logging +import re +import unicodedata from typing import Any, Dict, List, Optional from sqlalchemy.orm import Session @@ -22,6 +24,7 @@ from sqlalchemy.orm import Session from superset.connectors.druid.models import DruidCluster from superset.models.core import Database +IMPORT_EXPORT_VERSION = "1.0.0" DATABASES_KEY = "databases" DRUID_CLUSTERS_KEY = "druid_clusters" logger = logging.getLogger(__name__) @@ -95,3 +98,18 @@ def import_from_dict( session.commit() else: logger.info("Supplied object is not a dictionary.") + + +def strip_accents(text: str) -> str: + text = unicodedata.normalize("NFD", text).encode("ascii", "ignore").decode("utf-8") + + return str(text) + + +def sanitize(name: str) -> str: + """Sanitize a post title into a directory name.""" + name = name.lower().replace(" ", "_") + name = re.sub(r"[^\w]", "", name) + name = strip_accents(name) + + return name diff --git a/tests/charts/api_tests.py b/tests/charts/api_tests.py index ed4003b78..c7023a825 100644 --- a/tests/charts/api_tests.py +++ b/tests/charts/api_tests.py @@ -962,7 +962,6 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin): } ] rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") - print(rv.data) self.assertEqual(rv.status_code, 200) response_payload = json.loads(rv.data.decode("utf-8")) result = response_payload["result"][0] diff --git a/tests/dashboards/api_tests.py b/tests/dashboards/api_tests.py index 09c14a0cd..6fc0385cc 100644 --- a/tests/dashboards/api_tests.py +++ b/tests/dashboards/api_tests.py @@ -24,6 +24,7 @@ import prison from sqlalchemy.sql import func import tests.test_app +from freezegun import freeze_time from sqlalchemy import and_ from superset import db, security_manager from superset.models.dashboard import Dashboard @@ -955,12 +956,14 @@ class TestDashboardApi(SupersetTestCase, ApiOwnersTestCaseMixin): self.login(username="admin") argument = [1, 2] uri = f"api/v1/dashboard/export/?q={prison.dumps(argument)}" - rv = self.get_assert_metric(uri, "export") - self.assertEqual(rv.status_code, 200) - self.assertEqual( - rv.headers["Content-Disposition"], - generate_download_headers("json")["Content-Disposition"], - ) + + # freeze time to ensure filename is deterministic + with freeze_time("2020-01-01T00:00:00Z"): + rv = self.get_assert_metric(uri, "export") + headers = generate_download_headers("json")["Content-Disposition"] + + assert rv.status_code == 200 + assert rv.headers["Content-Disposition"] == headers def test_export_not_found(self): """ diff --git a/tests/databases/api_tests.py b/tests/databases/api_tests.py index a5ce0b408..9b27f82b1 100644 --- a/tests/databases/api_tests.py +++ b/tests/databases/api_tests.py @@ -18,6 +18,8 @@ """Unit tests for Superset""" import datetime import json +from io import BytesIO +from zipfile import is_zipfile import pandas as pd import prison @@ -801,3 +803,45 @@ class TestDatabaseApi(SupersetTestCase): uri = f"api/v1/database/{database.id}/related_objects/" rv = self.client.get(uri) self.assertEqual(rv.status_code, 404) + + def test_export_database(self): + """ + Database API: Test export database + """ + self.login(username="admin") + database = get_example_database() + argument = [database.id] + uri = f"api/v1/database/export/?q={prison.dumps(argument)}" + rv = self.client.get(uri) + + assert rv.status_code == 200 + + buf = BytesIO(rv.data) + assert is_zipfile(buf) + + def test_export_database_not_allowed(self): + """ + Database API: Test export database not allowed + """ + self.login(username="gamma") + database = get_example_database() + argument = [database.id] + uri = f"api/v1/database/export/?q={prison.dumps(argument)}" + rv = self.client.get(uri) + + assert rv.status_code == 401 + + def test_export_database_non_existing(self): + """ + Database API: Test export database not allowed + """ + max_id = db.session.query(func.max(Database.id)).scalar() + # id does not exist and we get 404 + invalid_id = max_id + 1 + + self.login(username="admin") + argument = [invalid_id] + uri = f"api/v1/database/export/?q={prison.dumps(argument)}" + rv = self.client.get(uri) + + assert rv.status_code == 404 diff --git a/tests/databases/commands_tests.py b/tests/databases/commands_tests.py new file mode 100644 index 000000000..cb2fcb3a6 --- /dev/null +++ b/tests/databases/commands_tests.py @@ -0,0 +1,266 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from unittest.mock import patch + +import yaml + +from superset import db, security_manager +from superset.databases.commands.exceptions import DatabaseNotFoundError +from superset.databases.commands.export import ExportDatabasesCommand +from superset.models.core import Database +from superset.utils.core import backend, get_example_database +from tests.base_tests import SupersetTestCase + + +class TestExportDatabasesCommand(SupersetTestCase): + @patch("superset.security.manager.g") + def test_export_database_command(self, mock_g): + mock_g.user = security_manager.find_user("admin") + + example_db = get_example_database() + command = ExportDatabasesCommand(database_ids=[example_db.id]) + contents = dict(command.run()) + + # TODO: this list shouldn't depend on the order in which unit tests are run + # or on the backend; for now use a stable subset + core_datasets = { + "databases/examples.yaml", + "datasets/energy_usage.yaml", + "datasets/wb_health_population.yaml", + "datasets/birth_names.yaml", + } + expected_extra = { + "engine_params": {}, + "metadata_cache_timeout": {}, + "metadata_params": {}, + "schemas_allowed_for_csv_upload": [], + } + if backend() == "presto": + expected_extra = {"engine_params": {"connect_args": {"poll_interval": 0.1}}} + + assert core_datasets.issubset(set(contents.keys())) + + metadata = yaml.safe_load(contents["databases/examples.yaml"]) + assert metadata == ( + { + "allow_csv_upload": True, + "allow_ctas": True, + "allow_cvas": True, + "allow_run_async": False, + "cache_timeout": None, + "database_name": "examples", + "expose_in_sqllab": True, + "extra": expected_extra, + "sqlalchemy_uri": example_db.sqlalchemy_uri, + "uuid": str(example_db.uuid), + "version": "1.0.0", + } + ) + + metadata = yaml.safe_load(contents["datasets/birth_names.yaml"]) + metadata.pop("uuid") + assert metadata == { + "table_name": "birth_names", + "main_dttm_col": None, + "description": "Adding a DESCRip", + "default_endpoint": "", + "offset": 66, + "cache_timeout": 55, + "schema": "", + "sql": "", + "params": None, + "template_params": None, + "filter_select_enabled": True, + "fetch_values_predicate": None, + "metrics": [ + { + "metric_name": "ratio", + "verbose_name": "Ratio Boys/Girls", + "metric_type": None, + "expression": "sum(sum_boys) / sum(sum_girls)", + "description": "This represents the ratio of boys/girls", + "d3format": ".2%", + "extra": None, + "warning_text": "no warning", + }, + { + "metric_name": "sum__num", + "verbose_name": "Babies", + "metric_type": None, + "expression": "SUM(num)", + "description": "", + "d3format": "", + "extra": None, + "warning_text": "", + }, + { + "metric_name": "count", + "verbose_name": "", + "metric_type": None, + "expression": "count(1)", + "description": None, + "d3format": None, + "extra": None, + "warning_text": None, + }, + ], + "columns": [ + { + "column_name": "num_california", + "verbose_name": None, + "is_dttm": False, + "is_active": None, + "type": "NUMBER", + "groupby": False, + "filterable": False, + "expression": "CASE WHEN state = 'CA' THEN num ELSE 0 END", + "description": None, + "python_date_format": None, + }, + { + "column_name": "ds", + "verbose_name": "", + "is_dttm": True, + "is_active": None, + "type": "DATETIME", + "groupby": True, + "filterable": True, + "expression": "", + "description": None, + "python_date_format": None, + }, + { + "column_name": "sum_girls", + "verbose_name": None, + "is_dttm": False, + "is_active": None, + "type": "BIGINT(20)", + "groupby": False, + "filterable": False, + "expression": "", + "description": None, + "python_date_format": None, + }, + { + "column_name": "gender", + "verbose_name": None, + "is_dttm": False, + "is_active": None, + "type": "VARCHAR(16)", + "groupby": True, + "filterable": True, + "expression": "", + "description": None, + "python_date_format": None, + }, + { + "column_name": "state", + "verbose_name": None, + "is_dttm": None, + "is_active": None, + "type": "VARCHAR(10)", + "groupby": True, + "filterable": True, + "expression": None, + "description": None, + "python_date_format": None, + }, + { + "column_name": "sum_boys", + "verbose_name": None, + "is_dttm": None, + "is_active": None, + "type": "BIGINT(20)", + "groupby": True, + "filterable": True, + "expression": None, + "description": None, + "python_date_format": None, + }, + { + "column_name": "num", + "verbose_name": None, + "is_dttm": None, + "is_active": None, + "type": "BIGINT(20)", + "groupby": True, + "filterable": True, + "expression": None, + "description": None, + "python_date_format": None, + }, + { + "column_name": "name", + "verbose_name": None, + "is_dttm": None, + "is_active": None, + "type": "VARCHAR(255)", + "groupby": True, + "filterable": True, + "expression": None, + "description": None, + "python_date_format": None, + }, + ], + "version": "1.0.0", + "database_uuid": str(example_db.uuid), + } + + @patch("superset.security.manager.g") + def test_export_database_command_no_access(self, mock_g): + """Test that users can't export databases they don't have access to""" + mock_g.user = security_manager.find_user("gamma") + + example_db = get_example_database() + command = ExportDatabasesCommand(database_ids=[example_db.id]) + contents = command.run() + with self.assertRaises(DatabaseNotFoundError): + next(contents) + + @patch("superset.security.manager.g") + def test_export_database_command_invalid_database(self, mock_g): + """Test that an error is raised when exporting an invalid database""" + mock_g.user = security_manager.find_user("admin") + command = ExportDatabasesCommand(database_ids=[-1]) + contents = command.run() + with self.assertRaises(DatabaseNotFoundError): + next(contents) + + @patch("superset.security.manager.g") + def test_export_database_command_key_order(self, mock_g): + """Test that they keys in the YAML have the same order as export_fields""" + mock_g.user = security_manager.find_user("admin") + + example_db = get_example_database() + command = ExportDatabasesCommand(database_ids=[example_db.id]) + contents = dict(command.run()) + + metadata = yaml.safe_load(contents["databases/examples.yaml"]) + assert list(metadata.keys()) == [ + "database_name", + "sqlalchemy_uri", + "cache_timeout", + "expose_in_sqllab", + "allow_run_async", + "allow_ctas", + "allow_cvas", + "allow_csv_upload", + "extra", + "uuid", + "version", + ] diff --git a/tests/datasets/api_tests.py b/tests/datasets/api_tests.py index a2ff652fa..cd62a2ae8 100644 --- a/tests/datasets/api_tests.py +++ b/tests/datasets/api_tests.py @@ -216,10 +216,9 @@ class TestDatasetApi(SupersetTestCase): "admin_database", "information_schema", "public", - "superset", ] expected_response = { - "count": 5, + "count": 4, "result": [{"text": val, "value": val} for val in schema_values], } self.login(username="admin") @@ -243,14 +242,14 @@ class TestDatasetApi(SupersetTestCase): query_parameter = {"page": 0, "page_size": 1} pg_test_query_parameter( - query_parameter, {"count": 5, "result": [{"text": "", "value": ""}]}, + query_parameter, {"count": 4, "result": [{"text": "", "value": ""}]}, ) query_parameter = {"page": 1, "page_size": 1} pg_test_query_parameter( query_parameter, { - "count": 5, + "count": 4, "result": [{"text": "admin_database", "value": "admin_database"}], }, ) diff --git a/tests/sqla_models_tests.py b/tests/sqla_models_tests.py index de46e1c55..9cdea16ef 100644 --- a/tests/sqla_models_tests.py +++ b/tests/sqla_models_tests.py @@ -20,6 +20,7 @@ from unittest.mock import patch import pytest import tests.test_app +from superset import db from superset.connectors.sqla.models import SqlaTable, TableColumn from superset.db_engine_specs.druid import DruidEngineSpec from superset.exceptions import QueryObjectValidationError @@ -86,49 +87,54 @@ class TestDatabaseModel(SupersetTestCase): } # Table with Jinja callable. - table = SqlaTable( + table1 = SqlaTable( table_name="test_has_extra_cache_keys_table", sql="SELECT '{{ current_username() }}' as user", database=get_example_database(), ) query_obj = dict(**base_query_obj, extras={}) - extra_cache_keys = table.get_extra_cache_keys(query_obj) - self.assertTrue(table.has_extra_cache_key_calls(query_obj)) + extra_cache_keys = table1.get_extra_cache_keys(query_obj) + self.assertTrue(table1.has_extra_cache_key_calls(query_obj)) assert extra_cache_keys == ["abc"] # Table with Jinja callable disabled. - table = SqlaTable( + table2 = SqlaTable( table_name="test_has_extra_cache_keys_disabled_table", sql="SELECT '{{ current_username(False) }}' as user", database=get_example_database(), ) query_obj = dict(**base_query_obj, extras={}) - extra_cache_keys = table.get_extra_cache_keys(query_obj) - self.assertTrue(table.has_extra_cache_key_calls(query_obj)) + extra_cache_keys = table2.get_extra_cache_keys(query_obj) + self.assertTrue(table2.has_extra_cache_key_calls(query_obj)) self.assertListEqual(extra_cache_keys, []) # Table with no Jinja callable. query = "SELECT 'abc' as user" - table = SqlaTable( + table3 = SqlaTable( table_name="test_has_no_extra_cache_keys_table", sql=query, database=get_example_database(), ) query_obj = dict(**base_query_obj, extras={"where": "(user != 'abc')"}) - extra_cache_keys = table.get_extra_cache_keys(query_obj) - self.assertFalse(table.has_extra_cache_key_calls(query_obj)) + extra_cache_keys = table3.get_extra_cache_keys(query_obj) + self.assertFalse(table3.has_extra_cache_key_calls(query_obj)) self.assertListEqual(extra_cache_keys, []) # With Jinja callable in SQL expression. query_obj = dict( **base_query_obj, extras={"where": "(user != '{{ current_username() }}')"} ) - extra_cache_keys = table.get_extra_cache_keys(query_obj) - self.assertTrue(table.has_extra_cache_key_calls(query_obj)) + extra_cache_keys = table3.get_extra_cache_keys(query_obj) + self.assertTrue(table3.has_extra_cache_key_calls(query_obj)) assert extra_cache_keys == ["abc"] + # Cleanup + for table in [table1, table2, table3]: + db.session.delete(table) + db.session.commit() + def test_where_operators(self): class FilterTestCase(NamedTuple): operator: str diff --git a/tests/sqllab_tests.py b/tests/sqllab_tests.py index cff4796a5..97e4f3fe0 100644 --- a/tests/sqllab_tests.py +++ b/tests/sqllab_tests.py @@ -384,6 +384,10 @@ class TestSqlLab(SupersetTestCase): view_menu = security_manager.find_view_menu(table.get_perm()) assert view_menu is not None + # Cleanup + db.session.delete(table) + db.session.commit() + def test_sqllab_viz_bad_payload(self): self.login("admin") payload = {