feat: export databases as a ZIP bundle (#11229)
* Export databases as Zip file * Fix tests * Address comments * Implement mulexport for database * Fix lint * Fix lint
This commit is contained in:
parent
8863c939ad
commit
94e23bfc82
|
|
@ -16,8 +16,9 @@
|
|||
#
|
||||
-r development.in
|
||||
-r integration.in
|
||||
flask-testing
|
||||
docker
|
||||
flask-testing
|
||||
freezegun
|
||||
ipdb
|
||||
# pinning ipython as pip-compile-multi was bringing higher version
|
||||
# of the ipython that was not found in CI
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
# SHA1:0e68e30f4e1bc76d0ec05267a1e38451c3901384
|
||||
# SHA1:b16b83f856b2dbc53535f71414f5c3e8dfa838e0
|
||||
#
|
||||
# This file is autogenerated by pip-compile-multi
|
||||
# To update, run:
|
||||
|
|
@ -14,31 +14,30 @@ backcall==0.2.0 # via ipython
|
|||
coverage==5.3 # via pytest-cov
|
||||
docker==4.3.1 # via -r requirements/testing.in
|
||||
flask-testing==0.8.0 # via -r requirements/testing.in
|
||||
freezegun==1.0.0 # via -r requirements/testing.in
|
||||
iniconfig==1.0.1 # via pytest
|
||||
ipdb==0.13.3 # via -r requirements/testing.in
|
||||
ipdb==0.13.4 # via -r requirements/testing.in
|
||||
ipython-genutils==0.2.0 # via traitlets
|
||||
ipython==7.16.1 # via -r requirements/testing.in, ipdb
|
||||
isort==5.5.3 # via pylint
|
||||
isort==5.6.4 # via pylint
|
||||
jedi==0.17.2 # via ipython
|
||||
lazy-object-proxy==1.4.3 # via astroid
|
||||
mccabe==0.6.1 # via pylint
|
||||
more-itertools==8.5.0 # via pytest
|
||||
openapi-spec-validator==0.2.9 # via -r requirements/testing.in
|
||||
parameterized==0.7.4 # via -r requirements/testing.in
|
||||
parso==0.7.1 # via jedi
|
||||
pexpect==4.8.0 # via ipython
|
||||
pickleshare==0.7.5 # via ipython
|
||||
prompt-toolkit==3.0.7 # via ipython
|
||||
prompt-toolkit==3.0.8 # via ipython
|
||||
ptyprocess==0.6.0 # via pexpect
|
||||
pygments==2.7.1 # via ipython
|
||||
pyhive[hive,presto]==0.6.3 # via -r requirements/development.in, -r requirements/testing.in
|
||||
pylint==2.6.0 # via -r requirements/testing.in
|
||||
pytest-cov==2.10.1 # via -r requirements/testing.in
|
||||
pytest==6.0.2 # via -r requirements/testing.in, pytest-cov
|
||||
pytest==6.1.1 # via -r requirements/testing.in, pytest-cov
|
||||
redis==3.5.3 # via -r requirements/testing.in
|
||||
statsd==3.3.0 # via -r requirements/testing.in
|
||||
traitlets==5.0.4 # via ipython
|
||||
typed-ast==1.4.1 # via astroid
|
||||
wcwidth==0.2.5 # via prompt-toolkit
|
||||
websocket-client==0.57.0 # via docker
|
||||
wrapt==1.12.1 # via astroid
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ combine_as_imports = true
|
|||
include_trailing_comma = true
|
||||
line_length = 88
|
||||
known_first_party = superset
|
||||
known_third_party =alembic,apispec,backoff,bleach,cachelib,celery,click,colorama,contextlib2,croniter,cryptography,dateutil,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,geohash,geopy,humanize,isodate,jinja2,markdown,markupsafe,marshmallow,msgpack,numpy,pandas,parameterized,parsedatetime,pathlib2,polyline,prison,pyarrow,pyhive,pytest,pytz,retry,selenium,setuptools,simplejson,slack,sqlalchemy,sqlalchemy_utils,sqlparse,werkzeug,wtforms,wtforms_json,yaml
|
||||
known_third_party =alembic,apispec,backoff,bleach,cachelib,celery,click,colorama,contextlib2,croniter,cryptography,dateutil,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,freezegun,geohash,geopy,humanize,isodate,jinja2,markdown,markupsafe,marshmallow,msgpack,numpy,pandas,parameterized,parsedatetime,pathlib2,polyline,prison,pyarrow,pyhive,pytest,pytz,retry,selenium,setuptools,simplejson,slack,sqlalchemy,sqlalchemy_utils,sqlparse,werkzeug,wtforms,wtforms_json,yaml
|
||||
multi_line_output = 3
|
||||
order_by_type = false
|
||||
|
||||
|
|
|
|||
|
|
@ -15,9 +15,12 @@
|
|||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from io import BytesIO
|
||||
from typing import Any, Optional
|
||||
from zipfile import ZipFile
|
||||
|
||||
from flask import g, request, Response
|
||||
from flask import g, request, Response, send_file
|
||||
from flask_appbuilder.api import expose, protect, rison, safe
|
||||
from flask_appbuilder.models.sqla.interface import SQLAInterface
|
||||
from flask_babel import gettext as _
|
||||
|
|
@ -43,6 +46,7 @@ from superset.databases.commands.exceptions import (
|
|||
DatabaseSecurityUnsafeError,
|
||||
DatabaseUpdateFailedError,
|
||||
)
|
||||
from superset.databases.commands.export import ExportDatabasesCommand
|
||||
from superset.databases.commands.test_connection import TestConnectionDatabaseCommand
|
||||
from superset.databases.commands.update import UpdateDatabaseCommand
|
||||
from superset.databases.dao import DatabaseDAO
|
||||
|
|
@ -54,6 +58,7 @@ from superset.databases.schemas import (
|
|||
DatabasePutSchema,
|
||||
DatabaseRelatedObjectsResponse,
|
||||
DatabaseTestConnectionSchema,
|
||||
get_export_ids_schema,
|
||||
SchemasResponseSchema,
|
||||
SelectStarResponseSchema,
|
||||
TableMetadataResponseSchema,
|
||||
|
|
@ -72,6 +77,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
|||
datamodel = SQLAInterface(Database)
|
||||
|
||||
include_route_methods = RouteMethod.REST_MODEL_VIEW_CRUD_SET | {
|
||||
RouteMethod.EXPORT,
|
||||
"table_metadata",
|
||||
"select_star",
|
||||
"schemas",
|
||||
|
|
@ -653,3 +659,61 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
|||
charts={"count": len(charts), "result": charts},
|
||||
dashboards={"count": len(dashboards), "result": dashboards},
|
||||
)
|
||||
|
||||
@expose("/export/", methods=["GET"])
|
||||
@protect()
|
||||
@safe
|
||||
@statsd_metrics
|
||||
@rison(get_export_ids_schema)
|
||||
def export(self, **kwargs: Any) -> Response:
|
||||
"""Export database(s) with associated datasets
|
||||
---
|
||||
get:
|
||||
description: Download database(s) and associated dataset(s) as a zip file
|
||||
parameters:
|
||||
- in: query
|
||||
name: q
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: array
|
||||
items:
|
||||
type: integer
|
||||
responses:
|
||||
200:
|
||||
description: A zip file with database(s) and dataset(s) as YAML
|
||||
content:
|
||||
application/zip:
|
||||
schema:
|
||||
type: string
|
||||
format: binary
|
||||
401:
|
||||
$ref: '#/components/responses/401'
|
||||
404:
|
||||
$ref: '#/components/responses/404'
|
||||
500:
|
||||
$ref: '#/components/responses/500'
|
||||
"""
|
||||
requested_ids = kwargs["rison"]
|
||||
timestamp = datetime.now().strftime("%Y%m%dT%H%M%S")
|
||||
root = f"database_export_{timestamp}"
|
||||
filename = f"{root}.zip"
|
||||
|
||||
buf = BytesIO()
|
||||
with ZipFile(buf, "w") as bundle:
|
||||
try:
|
||||
for file_name, file_content in ExportDatabasesCommand(
|
||||
requested_ids
|
||||
).run():
|
||||
with bundle.open(f"{root}/{file_name}", "w") as fp:
|
||||
fp.write(file_content.encode())
|
||||
except DatabaseNotFoundError:
|
||||
return self.response_404()
|
||||
buf.seek(0)
|
||||
|
||||
return send_file(
|
||||
buf,
|
||||
mimetype="application/zip",
|
||||
as_attachment=True,
|
||||
attachment_filename=filename,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ from superset.security.analytics_db_safety import DBSecurityException
|
|||
|
||||
|
||||
class DatabaseInvalidError(CommandInvalidError):
|
||||
message = _("Dashboard parameters are invalid.")
|
||||
message = _("Database parameters are invalid.")
|
||||
|
||||
|
||||
class DatabaseExistsValidationError(ValidationError):
|
||||
|
|
|
|||
|
|
@ -0,0 +1,89 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
# isort:skip_file
|
||||
|
||||
import json
|
||||
from typing import Iterator, List, Tuple
|
||||
|
||||
import yaml
|
||||
|
||||
from superset.commands.base import BaseCommand
|
||||
from superset.databases.commands.exceptions import DatabaseNotFoundError
|
||||
from superset.databases.dao import DatabaseDAO
|
||||
from superset.utils.dict_import_export import IMPORT_EXPORT_VERSION, sanitize
|
||||
from superset.models.core import Database
|
||||
|
||||
|
||||
class ExportDatabasesCommand(BaseCommand):
|
||||
def __init__(self, database_ids: List[int]):
|
||||
self.database_ids = database_ids
|
||||
|
||||
# this will be set when calling validate()
|
||||
self._models: List[Database] = []
|
||||
|
||||
@staticmethod
|
||||
def export_database(database: Database) -> Iterator[Tuple[str, str]]:
|
||||
name = sanitize(database.database_name)
|
||||
file_name = f"databases/{name}.yaml"
|
||||
|
||||
payload = database.export_to_dict(
|
||||
recursive=False,
|
||||
include_parent_ref=False,
|
||||
include_defaults=True,
|
||||
export_uuids=True,
|
||||
)
|
||||
# TODO (betodealmeida): move this logic to export_to_dict once this
|
||||
# becomes the default export endpoint
|
||||
if "extra" in payload:
|
||||
try:
|
||||
payload["extra"] = json.loads(payload["extra"])
|
||||
except json.decoder.JSONDecodeError:
|
||||
pass
|
||||
|
||||
payload["version"] = IMPORT_EXPORT_VERSION
|
||||
|
||||
file_content = yaml.safe_dump(payload, sort_keys=False)
|
||||
yield file_name, file_content
|
||||
|
||||
# TODO (betodealmeida): reuse logic from ExportDatasetCommand once
|
||||
# it's implemented
|
||||
for dataset in database.tables:
|
||||
name = sanitize(dataset.table_name)
|
||||
file_name = f"datasets/{name}.yaml"
|
||||
|
||||
payload = dataset.export_to_dict(
|
||||
recursive=True,
|
||||
include_parent_ref=False,
|
||||
include_defaults=True,
|
||||
export_uuids=True,
|
||||
)
|
||||
payload["version"] = IMPORT_EXPORT_VERSION
|
||||
payload["database_uuid"] = str(database.uuid)
|
||||
|
||||
file_content = yaml.safe_dump(payload, sort_keys=False)
|
||||
yield file_name, file_content
|
||||
|
||||
def run(self) -> Iterator[Tuple[str, str]]:
|
||||
self.validate()
|
||||
|
||||
for database in self._models:
|
||||
yield from self.export_database(database)
|
||||
|
||||
def validate(self) -> None:
|
||||
self._models = DatabaseDAO.find_by_ids(self.database_ids)
|
||||
if len(self._models) != len(self.database_ids):
|
||||
raise DatabaseNotFoundError()
|
||||
|
|
@ -109,6 +109,7 @@ extra_description = markdown(
|
|||
"whether or not the Explore button in SQL Lab results is shown.",
|
||||
True,
|
||||
)
|
||||
get_export_ids_schema = {"type": "array", "items": {"type": "integer"}}
|
||||
sqlalchemy_uri_description = markdown(
|
||||
"Refer to the "
|
||||
"[SqlAlchemy docs]"
|
||||
|
|
|
|||
|
|
@ -52,6 +52,23 @@ def json_to_dict(json_str: str) -> Dict[Any, Any]:
|
|||
return {}
|
||||
|
||||
|
||||
def convert_uuids(obj: Any) -> Any:
|
||||
"""
|
||||
Convert UUID objects to str so we can use yaml.safe_dump
|
||||
"""
|
||||
if isinstance(obj, uuid.UUID):
|
||||
return str(obj)
|
||||
|
||||
if isinstance(obj, list):
|
||||
return [convert_uuids(el) for el in obj]
|
||||
|
||||
if isinstance(obj, dict):
|
||||
return {k: convert_uuids(v) for k, v in obj.items()}
|
||||
|
||||
return obj
|
||||
|
||||
|
||||
# TODO (betodealmeida): rename to ImportExportMixin
|
||||
class ImportMixin:
|
||||
uuid = sa.Column(
|
||||
UUIDType(binary=True), primary_key=False, unique=True, default=uuid.uuid4
|
||||
|
|
@ -247,8 +264,15 @@ class ImportMixin:
|
|||
recursive: bool = True,
|
||||
include_parent_ref: bool = False,
|
||||
include_defaults: bool = False,
|
||||
export_uuids: bool = False,
|
||||
) -> Dict[Any, Any]:
|
||||
"""Export obj to dictionary"""
|
||||
export_fields = set(self.export_fields)
|
||||
if export_uuids:
|
||||
export_fields.add("uuid")
|
||||
if "id" in export_fields:
|
||||
export_fields.remove("id")
|
||||
|
||||
cls = self.__class__
|
||||
parent_excludes = set()
|
||||
if recursive and not include_parent_ref:
|
||||
|
|
@ -259,7 +283,7 @@ class ImportMixin:
|
|||
c.name: getattr(self, c.name)
|
||||
for c in cls.__table__.columns # type: ignore
|
||||
if (
|
||||
c.name in self.export_fields
|
||||
c.name in export_fields
|
||||
and c.name not in parent_excludes
|
||||
and (
|
||||
include_defaults
|
||||
|
|
@ -270,6 +294,13 @@ class ImportMixin:
|
|||
)
|
||||
)
|
||||
}
|
||||
|
||||
# sort according to export_fields using DSU (decorate, sort, undecorate)
|
||||
order = {field: i for i, field in enumerate(self.export_fields)}
|
||||
decorated_keys = [(order.get(k, len(order)), k) for k in dict_rep]
|
||||
decorated_keys.sort()
|
||||
dict_rep = {k: dict_rep[k] for _, k in decorated_keys}
|
||||
|
||||
if recursive:
|
||||
for cld in self.export_children:
|
||||
# sorting to make lists of children stable
|
||||
|
|
@ -285,7 +316,7 @@ class ImportMixin:
|
|||
key=lambda k: sorted(str(k.items())),
|
||||
)
|
||||
|
||||
return dict_rep
|
||||
return convert_uuids(dict_rep)
|
||||
|
||||
def override(self, obj: Any) -> None:
|
||||
"""Overrides the plain fields of the dashboard."""
|
||||
|
|
|
|||
|
|
@ -15,6 +15,8 @@
|
|||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import logging
|
||||
import re
|
||||
import unicodedata
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
|
@ -22,6 +24,7 @@ from sqlalchemy.orm import Session
|
|||
from superset.connectors.druid.models import DruidCluster
|
||||
from superset.models.core import Database
|
||||
|
||||
IMPORT_EXPORT_VERSION = "1.0.0"
|
||||
DATABASES_KEY = "databases"
|
||||
DRUID_CLUSTERS_KEY = "druid_clusters"
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -95,3 +98,18 @@ def import_from_dict(
|
|||
session.commit()
|
||||
else:
|
||||
logger.info("Supplied object is not a dictionary.")
|
||||
|
||||
|
||||
def strip_accents(text: str) -> str:
|
||||
text = unicodedata.normalize("NFD", text).encode("ascii", "ignore").decode("utf-8")
|
||||
|
||||
return str(text)
|
||||
|
||||
|
||||
def sanitize(name: str) -> str:
|
||||
"""Sanitize a post title into a directory name."""
|
||||
name = name.lower().replace(" ", "_")
|
||||
name = re.sub(r"[^\w]", "", name)
|
||||
name = strip_accents(name)
|
||||
|
||||
return name
|
||||
|
|
|
|||
|
|
@ -962,7 +962,6 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin):
|
|||
}
|
||||
]
|
||||
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
|
||||
print(rv.data)
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
response_payload = json.loads(rv.data.decode("utf-8"))
|
||||
result = response_payload["result"][0]
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ import prison
|
|||
from sqlalchemy.sql import func
|
||||
|
||||
import tests.test_app
|
||||
from freezegun import freeze_time
|
||||
from sqlalchemy import and_
|
||||
from superset import db, security_manager
|
||||
from superset.models.dashboard import Dashboard
|
||||
|
|
@ -955,12 +956,14 @@ class TestDashboardApi(SupersetTestCase, ApiOwnersTestCaseMixin):
|
|||
self.login(username="admin")
|
||||
argument = [1, 2]
|
||||
uri = f"api/v1/dashboard/export/?q={prison.dumps(argument)}"
|
||||
rv = self.get_assert_metric(uri, "export")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
self.assertEqual(
|
||||
rv.headers["Content-Disposition"],
|
||||
generate_download_headers("json")["Content-Disposition"],
|
||||
)
|
||||
|
||||
# freeze time to ensure filename is deterministic
|
||||
with freeze_time("2020-01-01T00:00:00Z"):
|
||||
rv = self.get_assert_metric(uri, "export")
|
||||
headers = generate_download_headers("json")["Content-Disposition"]
|
||||
|
||||
assert rv.status_code == 200
|
||||
assert rv.headers["Content-Disposition"] == headers
|
||||
|
||||
def test_export_not_found(self):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -18,6 +18,8 @@
|
|||
"""Unit tests for Superset"""
|
||||
import datetime
|
||||
import json
|
||||
from io import BytesIO
|
||||
from zipfile import is_zipfile
|
||||
|
||||
import pandas as pd
|
||||
import prison
|
||||
|
|
@ -801,3 +803,45 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
uri = f"api/v1/database/{database.id}/related_objects/"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 404)
|
||||
|
||||
def test_export_database(self):
|
||||
"""
|
||||
Database API: Test export database
|
||||
"""
|
||||
self.login(username="admin")
|
||||
database = get_example_database()
|
||||
argument = [database.id]
|
||||
uri = f"api/v1/database/export/?q={prison.dumps(argument)}"
|
||||
rv = self.client.get(uri)
|
||||
|
||||
assert rv.status_code == 200
|
||||
|
||||
buf = BytesIO(rv.data)
|
||||
assert is_zipfile(buf)
|
||||
|
||||
def test_export_database_not_allowed(self):
|
||||
"""
|
||||
Database API: Test export database not allowed
|
||||
"""
|
||||
self.login(username="gamma")
|
||||
database = get_example_database()
|
||||
argument = [database.id]
|
||||
uri = f"api/v1/database/export/?q={prison.dumps(argument)}"
|
||||
rv = self.client.get(uri)
|
||||
|
||||
assert rv.status_code == 401
|
||||
|
||||
def test_export_database_non_existing(self):
|
||||
"""
|
||||
Database API: Test export database not allowed
|
||||
"""
|
||||
max_id = db.session.query(func.max(Database.id)).scalar()
|
||||
# id does not exist and we get 404
|
||||
invalid_id = max_id + 1
|
||||
|
||||
self.login(username="admin")
|
||||
argument = [invalid_id]
|
||||
uri = f"api/v1/database/export/?q={prison.dumps(argument)}"
|
||||
rv = self.client.get(uri)
|
||||
|
||||
assert rv.status_code == 404
|
||||
|
|
|
|||
|
|
@ -0,0 +1,266 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import yaml
|
||||
|
||||
from superset import db, security_manager
|
||||
from superset.databases.commands.exceptions import DatabaseNotFoundError
|
||||
from superset.databases.commands.export import ExportDatabasesCommand
|
||||
from superset.models.core import Database
|
||||
from superset.utils.core import backend, get_example_database
|
||||
from tests.base_tests import SupersetTestCase
|
||||
|
||||
|
||||
class TestExportDatabasesCommand(SupersetTestCase):
|
||||
@patch("superset.security.manager.g")
|
||||
def test_export_database_command(self, mock_g):
|
||||
mock_g.user = security_manager.find_user("admin")
|
||||
|
||||
example_db = get_example_database()
|
||||
command = ExportDatabasesCommand(database_ids=[example_db.id])
|
||||
contents = dict(command.run())
|
||||
|
||||
# TODO: this list shouldn't depend on the order in which unit tests are run
|
||||
# or on the backend; for now use a stable subset
|
||||
core_datasets = {
|
||||
"databases/examples.yaml",
|
||||
"datasets/energy_usage.yaml",
|
||||
"datasets/wb_health_population.yaml",
|
||||
"datasets/birth_names.yaml",
|
||||
}
|
||||
expected_extra = {
|
||||
"engine_params": {},
|
||||
"metadata_cache_timeout": {},
|
||||
"metadata_params": {},
|
||||
"schemas_allowed_for_csv_upload": [],
|
||||
}
|
||||
if backend() == "presto":
|
||||
expected_extra = {"engine_params": {"connect_args": {"poll_interval": 0.1}}}
|
||||
|
||||
assert core_datasets.issubset(set(contents.keys()))
|
||||
|
||||
metadata = yaml.safe_load(contents["databases/examples.yaml"])
|
||||
assert metadata == (
|
||||
{
|
||||
"allow_csv_upload": True,
|
||||
"allow_ctas": True,
|
||||
"allow_cvas": True,
|
||||
"allow_run_async": False,
|
||||
"cache_timeout": None,
|
||||
"database_name": "examples",
|
||||
"expose_in_sqllab": True,
|
||||
"extra": expected_extra,
|
||||
"sqlalchemy_uri": example_db.sqlalchemy_uri,
|
||||
"uuid": str(example_db.uuid),
|
||||
"version": "1.0.0",
|
||||
}
|
||||
)
|
||||
|
||||
metadata = yaml.safe_load(contents["datasets/birth_names.yaml"])
|
||||
metadata.pop("uuid")
|
||||
assert metadata == {
|
||||
"table_name": "birth_names",
|
||||
"main_dttm_col": None,
|
||||
"description": "Adding a DESCRip",
|
||||
"default_endpoint": "",
|
||||
"offset": 66,
|
||||
"cache_timeout": 55,
|
||||
"schema": "",
|
||||
"sql": "",
|
||||
"params": None,
|
||||
"template_params": None,
|
||||
"filter_select_enabled": True,
|
||||
"fetch_values_predicate": None,
|
||||
"metrics": [
|
||||
{
|
||||
"metric_name": "ratio",
|
||||
"verbose_name": "Ratio Boys/Girls",
|
||||
"metric_type": None,
|
||||
"expression": "sum(sum_boys) / sum(sum_girls)",
|
||||
"description": "This represents the ratio of boys/girls",
|
||||
"d3format": ".2%",
|
||||
"extra": None,
|
||||
"warning_text": "no warning",
|
||||
},
|
||||
{
|
||||
"metric_name": "sum__num",
|
||||
"verbose_name": "Babies",
|
||||
"metric_type": None,
|
||||
"expression": "SUM(num)",
|
||||
"description": "",
|
||||
"d3format": "",
|
||||
"extra": None,
|
||||
"warning_text": "",
|
||||
},
|
||||
{
|
||||
"metric_name": "count",
|
||||
"verbose_name": "",
|
||||
"metric_type": None,
|
||||
"expression": "count(1)",
|
||||
"description": None,
|
||||
"d3format": None,
|
||||
"extra": None,
|
||||
"warning_text": None,
|
||||
},
|
||||
],
|
||||
"columns": [
|
||||
{
|
||||
"column_name": "num_california",
|
||||
"verbose_name": None,
|
||||
"is_dttm": False,
|
||||
"is_active": None,
|
||||
"type": "NUMBER",
|
||||
"groupby": False,
|
||||
"filterable": False,
|
||||
"expression": "CASE WHEN state = 'CA' THEN num ELSE 0 END",
|
||||
"description": None,
|
||||
"python_date_format": None,
|
||||
},
|
||||
{
|
||||
"column_name": "ds",
|
||||
"verbose_name": "",
|
||||
"is_dttm": True,
|
||||
"is_active": None,
|
||||
"type": "DATETIME",
|
||||
"groupby": True,
|
||||
"filterable": True,
|
||||
"expression": "",
|
||||
"description": None,
|
||||
"python_date_format": None,
|
||||
},
|
||||
{
|
||||
"column_name": "sum_girls",
|
||||
"verbose_name": None,
|
||||
"is_dttm": False,
|
||||
"is_active": None,
|
||||
"type": "BIGINT(20)",
|
||||
"groupby": False,
|
||||
"filterable": False,
|
||||
"expression": "",
|
||||
"description": None,
|
||||
"python_date_format": None,
|
||||
},
|
||||
{
|
||||
"column_name": "gender",
|
||||
"verbose_name": None,
|
||||
"is_dttm": False,
|
||||
"is_active": None,
|
||||
"type": "VARCHAR(16)",
|
||||
"groupby": True,
|
||||
"filterable": True,
|
||||
"expression": "",
|
||||
"description": None,
|
||||
"python_date_format": None,
|
||||
},
|
||||
{
|
||||
"column_name": "state",
|
||||
"verbose_name": None,
|
||||
"is_dttm": None,
|
||||
"is_active": None,
|
||||
"type": "VARCHAR(10)",
|
||||
"groupby": True,
|
||||
"filterable": True,
|
||||
"expression": None,
|
||||
"description": None,
|
||||
"python_date_format": None,
|
||||
},
|
||||
{
|
||||
"column_name": "sum_boys",
|
||||
"verbose_name": None,
|
||||
"is_dttm": None,
|
||||
"is_active": None,
|
||||
"type": "BIGINT(20)",
|
||||
"groupby": True,
|
||||
"filterable": True,
|
||||
"expression": None,
|
||||
"description": None,
|
||||
"python_date_format": None,
|
||||
},
|
||||
{
|
||||
"column_name": "num",
|
||||
"verbose_name": None,
|
||||
"is_dttm": None,
|
||||
"is_active": None,
|
||||
"type": "BIGINT(20)",
|
||||
"groupby": True,
|
||||
"filterable": True,
|
||||
"expression": None,
|
||||
"description": None,
|
||||
"python_date_format": None,
|
||||
},
|
||||
{
|
||||
"column_name": "name",
|
||||
"verbose_name": None,
|
||||
"is_dttm": None,
|
||||
"is_active": None,
|
||||
"type": "VARCHAR(255)",
|
||||
"groupby": True,
|
||||
"filterable": True,
|
||||
"expression": None,
|
||||
"description": None,
|
||||
"python_date_format": None,
|
||||
},
|
||||
],
|
||||
"version": "1.0.0",
|
||||
"database_uuid": str(example_db.uuid),
|
||||
}
|
||||
|
||||
@patch("superset.security.manager.g")
|
||||
def test_export_database_command_no_access(self, mock_g):
|
||||
"""Test that users can't export databases they don't have access to"""
|
||||
mock_g.user = security_manager.find_user("gamma")
|
||||
|
||||
example_db = get_example_database()
|
||||
command = ExportDatabasesCommand(database_ids=[example_db.id])
|
||||
contents = command.run()
|
||||
with self.assertRaises(DatabaseNotFoundError):
|
||||
next(contents)
|
||||
|
||||
@patch("superset.security.manager.g")
|
||||
def test_export_database_command_invalid_database(self, mock_g):
|
||||
"""Test that an error is raised when exporting an invalid database"""
|
||||
mock_g.user = security_manager.find_user("admin")
|
||||
command = ExportDatabasesCommand(database_ids=[-1])
|
||||
contents = command.run()
|
||||
with self.assertRaises(DatabaseNotFoundError):
|
||||
next(contents)
|
||||
|
||||
@patch("superset.security.manager.g")
|
||||
def test_export_database_command_key_order(self, mock_g):
|
||||
"""Test that they keys in the YAML have the same order as export_fields"""
|
||||
mock_g.user = security_manager.find_user("admin")
|
||||
|
||||
example_db = get_example_database()
|
||||
command = ExportDatabasesCommand(database_ids=[example_db.id])
|
||||
contents = dict(command.run())
|
||||
|
||||
metadata = yaml.safe_load(contents["databases/examples.yaml"])
|
||||
assert list(metadata.keys()) == [
|
||||
"database_name",
|
||||
"sqlalchemy_uri",
|
||||
"cache_timeout",
|
||||
"expose_in_sqllab",
|
||||
"allow_run_async",
|
||||
"allow_ctas",
|
||||
"allow_cvas",
|
||||
"allow_csv_upload",
|
||||
"extra",
|
||||
"uuid",
|
||||
"version",
|
||||
]
|
||||
|
|
@ -216,10 +216,9 @@ class TestDatasetApi(SupersetTestCase):
|
|||
"admin_database",
|
||||
"information_schema",
|
||||
"public",
|
||||
"superset",
|
||||
]
|
||||
expected_response = {
|
||||
"count": 5,
|
||||
"count": 4,
|
||||
"result": [{"text": val, "value": val} for val in schema_values],
|
||||
}
|
||||
self.login(username="admin")
|
||||
|
|
@ -243,14 +242,14 @@ class TestDatasetApi(SupersetTestCase):
|
|||
|
||||
query_parameter = {"page": 0, "page_size": 1}
|
||||
pg_test_query_parameter(
|
||||
query_parameter, {"count": 5, "result": [{"text": "", "value": ""}]},
|
||||
query_parameter, {"count": 4, "result": [{"text": "", "value": ""}]},
|
||||
)
|
||||
|
||||
query_parameter = {"page": 1, "page_size": 1}
|
||||
pg_test_query_parameter(
|
||||
query_parameter,
|
||||
{
|
||||
"count": 5,
|
||||
"count": 4,
|
||||
"result": [{"text": "admin_database", "value": "admin_database"}],
|
||||
},
|
||||
)
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ from unittest.mock import patch
|
|||
import pytest
|
||||
|
||||
import tests.test_app
|
||||
from superset import db
|
||||
from superset.connectors.sqla.models import SqlaTable, TableColumn
|
||||
from superset.db_engine_specs.druid import DruidEngineSpec
|
||||
from superset.exceptions import QueryObjectValidationError
|
||||
|
|
@ -86,49 +87,54 @@ class TestDatabaseModel(SupersetTestCase):
|
|||
}
|
||||
|
||||
# Table with Jinja callable.
|
||||
table = SqlaTable(
|
||||
table1 = SqlaTable(
|
||||
table_name="test_has_extra_cache_keys_table",
|
||||
sql="SELECT '{{ current_username() }}' as user",
|
||||
database=get_example_database(),
|
||||
)
|
||||
|
||||
query_obj = dict(**base_query_obj, extras={})
|
||||
extra_cache_keys = table.get_extra_cache_keys(query_obj)
|
||||
self.assertTrue(table.has_extra_cache_key_calls(query_obj))
|
||||
extra_cache_keys = table1.get_extra_cache_keys(query_obj)
|
||||
self.assertTrue(table1.has_extra_cache_key_calls(query_obj))
|
||||
assert extra_cache_keys == ["abc"]
|
||||
|
||||
# Table with Jinja callable disabled.
|
||||
table = SqlaTable(
|
||||
table2 = SqlaTable(
|
||||
table_name="test_has_extra_cache_keys_disabled_table",
|
||||
sql="SELECT '{{ current_username(False) }}' as user",
|
||||
database=get_example_database(),
|
||||
)
|
||||
query_obj = dict(**base_query_obj, extras={})
|
||||
extra_cache_keys = table.get_extra_cache_keys(query_obj)
|
||||
self.assertTrue(table.has_extra_cache_key_calls(query_obj))
|
||||
extra_cache_keys = table2.get_extra_cache_keys(query_obj)
|
||||
self.assertTrue(table2.has_extra_cache_key_calls(query_obj))
|
||||
self.assertListEqual(extra_cache_keys, [])
|
||||
|
||||
# Table with no Jinja callable.
|
||||
query = "SELECT 'abc' as user"
|
||||
table = SqlaTable(
|
||||
table3 = SqlaTable(
|
||||
table_name="test_has_no_extra_cache_keys_table",
|
||||
sql=query,
|
||||
database=get_example_database(),
|
||||
)
|
||||
|
||||
query_obj = dict(**base_query_obj, extras={"where": "(user != 'abc')"})
|
||||
extra_cache_keys = table.get_extra_cache_keys(query_obj)
|
||||
self.assertFalse(table.has_extra_cache_key_calls(query_obj))
|
||||
extra_cache_keys = table3.get_extra_cache_keys(query_obj)
|
||||
self.assertFalse(table3.has_extra_cache_key_calls(query_obj))
|
||||
self.assertListEqual(extra_cache_keys, [])
|
||||
|
||||
# With Jinja callable in SQL expression.
|
||||
query_obj = dict(
|
||||
**base_query_obj, extras={"where": "(user != '{{ current_username() }}')"}
|
||||
)
|
||||
extra_cache_keys = table.get_extra_cache_keys(query_obj)
|
||||
self.assertTrue(table.has_extra_cache_key_calls(query_obj))
|
||||
extra_cache_keys = table3.get_extra_cache_keys(query_obj)
|
||||
self.assertTrue(table3.has_extra_cache_key_calls(query_obj))
|
||||
assert extra_cache_keys == ["abc"]
|
||||
|
||||
# Cleanup
|
||||
for table in [table1, table2, table3]:
|
||||
db.session.delete(table)
|
||||
db.session.commit()
|
||||
|
||||
def test_where_operators(self):
|
||||
class FilterTestCase(NamedTuple):
|
||||
operator: str
|
||||
|
|
|
|||
|
|
@ -384,6 +384,10 @@ class TestSqlLab(SupersetTestCase):
|
|||
view_menu = security_manager.find_view_menu(table.get_perm())
|
||||
assert view_menu is not None
|
||||
|
||||
# Cleanup
|
||||
db.session.delete(table)
|
||||
db.session.commit()
|
||||
|
||||
def test_sqllab_viz_bad_payload(self):
|
||||
self.login("admin")
|
||||
payload = {
|
||||
|
|
|
|||
Loading…
Reference in New Issue