feat: export databases as a ZIP bundle (#11229)

* Export databases as Zip file

* Fix tests

* Address comments

* Implement multi-export for database

* Fix lint

* Fix lint
This commit is contained in:
Beto Dealmeida 2020-10-16 11:10:39 -07:00 committed by GitHub
parent 8863c939ad
commit 94e23bfc82
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 559 additions and 35 deletions

View File

@ -16,8 +16,9 @@
#
-r development.in
-r integration.in
flask-testing
docker
flask-testing
freezegun
ipdb
# pinning ipython as pip-compile-multi was bringing higher version
# of the ipython that was not found in CI

View File

@ -1,4 +1,4 @@
# SHA1:0e68e30f4e1bc76d0ec05267a1e38451c3901384
# SHA1:b16b83f856b2dbc53535f71414f5c3e8dfa838e0
#
# This file is autogenerated by pip-compile-multi
# To update, run:
@ -14,31 +14,30 @@ backcall==0.2.0 # via ipython
coverage==5.3 # via pytest-cov
docker==4.3.1 # via -r requirements/testing.in
flask-testing==0.8.0 # via -r requirements/testing.in
freezegun==1.0.0 # via -r requirements/testing.in
iniconfig==1.0.1 # via pytest
ipdb==0.13.3 # via -r requirements/testing.in
ipdb==0.13.4 # via -r requirements/testing.in
ipython-genutils==0.2.0 # via traitlets
ipython==7.16.1 # via -r requirements/testing.in, ipdb
isort==5.5.3 # via pylint
isort==5.6.4 # via pylint
jedi==0.17.2 # via ipython
lazy-object-proxy==1.4.3 # via astroid
mccabe==0.6.1 # via pylint
more-itertools==8.5.0 # via pytest
openapi-spec-validator==0.2.9 # via -r requirements/testing.in
parameterized==0.7.4 # via -r requirements/testing.in
parso==0.7.1 # via jedi
pexpect==4.8.0 # via ipython
pickleshare==0.7.5 # via ipython
prompt-toolkit==3.0.7 # via ipython
prompt-toolkit==3.0.8 # via ipython
ptyprocess==0.6.0 # via pexpect
pygments==2.7.1 # via ipython
pyhive[hive,presto]==0.6.3 # via -r requirements/development.in, -r requirements/testing.in
pylint==2.6.0 # via -r requirements/testing.in
pytest-cov==2.10.1 # via -r requirements/testing.in
pytest==6.0.2 # via -r requirements/testing.in, pytest-cov
pytest==6.1.1 # via -r requirements/testing.in, pytest-cov
redis==3.5.3 # via -r requirements/testing.in
statsd==3.3.0 # via -r requirements/testing.in
traitlets==5.0.4 # via ipython
typed-ast==1.4.1 # via astroid
wcwidth==0.2.5 # via prompt-toolkit
websocket-client==0.57.0 # via docker
wrapt==1.12.1 # via astroid

View File

@ -30,7 +30,7 @@ combine_as_imports = true
include_trailing_comma = true
line_length = 88
known_first_party = superset
known_third_party =alembic,apispec,backoff,bleach,cachelib,celery,click,colorama,contextlib2,croniter,cryptography,dateutil,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,geohash,geopy,humanize,isodate,jinja2,markdown,markupsafe,marshmallow,msgpack,numpy,pandas,parameterized,parsedatetime,pathlib2,polyline,prison,pyarrow,pyhive,pytest,pytz,retry,selenium,setuptools,simplejson,slack,sqlalchemy,sqlalchemy_utils,sqlparse,werkzeug,wtforms,wtforms_json,yaml
known_third_party =alembic,apispec,backoff,bleach,cachelib,celery,click,colorama,contextlib2,croniter,cryptography,dateutil,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,freezegun,geohash,geopy,humanize,isodate,jinja2,markdown,markupsafe,marshmallow,msgpack,numpy,pandas,parameterized,parsedatetime,pathlib2,polyline,prison,pyarrow,pyhive,pytest,pytz,retry,selenium,setuptools,simplejson,slack,sqlalchemy,sqlalchemy_utils,sqlparse,werkzeug,wtforms,wtforms_json,yaml
multi_line_output = 3
order_by_type = false

View File

@ -15,9 +15,12 @@
# specific language governing permissions and limitations
# under the License.
import logging
from datetime import datetime
from io import BytesIO
from typing import Any, Optional
from zipfile import ZipFile
from flask import g, request, Response
from flask import g, request, Response, send_file
from flask_appbuilder.api import expose, protect, rison, safe
from flask_appbuilder.models.sqla.interface import SQLAInterface
from flask_babel import gettext as _
@ -43,6 +46,7 @@ from superset.databases.commands.exceptions import (
DatabaseSecurityUnsafeError,
DatabaseUpdateFailedError,
)
from superset.databases.commands.export import ExportDatabasesCommand
from superset.databases.commands.test_connection import TestConnectionDatabaseCommand
from superset.databases.commands.update import UpdateDatabaseCommand
from superset.databases.dao import DatabaseDAO
@ -54,6 +58,7 @@ from superset.databases.schemas import (
DatabasePutSchema,
DatabaseRelatedObjectsResponse,
DatabaseTestConnectionSchema,
get_export_ids_schema,
SchemasResponseSchema,
SelectStarResponseSchema,
TableMetadataResponseSchema,
@ -72,6 +77,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
datamodel = SQLAInterface(Database)
include_route_methods = RouteMethod.REST_MODEL_VIEW_CRUD_SET | {
RouteMethod.EXPORT,
"table_metadata",
"select_star",
"schemas",
@ -653,3 +659,61 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
charts={"count": len(charts), "result": charts},
dashboards={"count": len(dashboards), "result": dashboards},
)
@expose("/export/", methods=["GET"])
@protect()
@safe
@statsd_metrics
@rison(get_export_ids_schema)
def export(self, **kwargs: Any) -> Response:
    """Export database(s) with associated datasets
    ---
    get:
      description: Download database(s) and associated dataset(s) as a zip file
      parameters:
      - in: query
        name: q
        content:
          application/json:
            schema:
              type: array
              items:
                type: integer
      responses:
        200:
          description: A zip file with database(s) and dataset(s) as YAML
          content:
            application/zip:
              schema:
                type: string
                format: binary
        401:
          $ref: '#/components/responses/401'
        404:
          $ref: '#/components/responses/404'
        500:
          $ref: '#/components/responses/500'
    """
    # list of database ids, already validated against get_export_ids_schema
    requested_ids = kwargs["rison"]
    # timestamp makes the bundle name unique per request
    timestamp = datetime.now().strftime("%Y%m%dT%H%M%S")
    root = f"database_export_{timestamp}"
    filename = f"{root}.zip"
    buf = BytesIO()
    with ZipFile(buf, "w") as bundle:
        try:
            # ExportDatabasesCommand yields (relative file name, YAML text)
            # pairs; each is written as an entry under the root folder.
            for file_name, file_content in ExportDatabasesCommand(
                requested_ids
            ).run():
                with bundle.open(f"{root}/{file_name}", "w") as fp:
                    fp.write(file_content.encode())
        except DatabaseNotFoundError:
            # validate() raises when any requested id cannot be resolved
            return self.response_404()
    # rewind so send_file streams the archive from the start
    buf.seek(0)
    # NOTE(review): attachment_filename was renamed download_name in
    # Flask 2.0 — confirm against the pinned Flask version.
    return send_file(
        buf,
        mimetype="application/zip",
        as_attachment=True,
        attachment_filename=filename,
    )

View File

@ -28,7 +28,7 @@ from superset.security.analytics_db_safety import DBSecurityException
class DatabaseInvalidError(CommandInvalidError):
message = _("Dashboard parameters are invalid.")
message = _("Database parameters are invalid.")
class DatabaseExistsValidationError(ValidationError):

View File

@ -0,0 +1,89 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
import json
from typing import Iterator, List, Tuple
import yaml
from superset.commands.base import BaseCommand
from superset.databases.commands.exceptions import DatabaseNotFoundError
from superset.databases.dao import DatabaseDAO
from superset.utils.dict_import_export import IMPORT_EXPORT_VERSION, sanitize
from superset.models.core import Database
class ExportDatabasesCommand(BaseCommand):
    """Export one or more databases, plus their datasets, as YAML files."""

    def __init__(self, database_ids: List[int]):
        self.database_ids = database_ids
        # populated by validate()
        self._models: List[Database] = []

    @staticmethod
    def export_database(database: Database) -> Iterator[Tuple[str, str]]:
        """Yield ``(relative file name, YAML content)`` pairs for a database."""
        db_payload = database.export_to_dict(
            recursive=False,
            include_parent_ref=False,
            include_defaults=True,
            export_uuids=True,
        )
        # "extra" is stored as a JSON string; inline it as a mapping so the
        # exported YAML stays human readable (invalid JSON is left as-is).
        # TODO (betodealmeida): move this logic to export_to_dict once this
        # becomes the default export endpoint
        if "extra" in db_payload:
            try:
                db_payload["extra"] = json.loads(db_payload["extra"])
            except json.decoder.JSONDecodeError:
                pass
        db_payload["version"] = IMPORT_EXPORT_VERSION
        yield (
            f"databases/{sanitize(database.database_name)}.yaml",
            yaml.safe_dump(db_payload, sort_keys=False),
        )

        # TODO (betodealmeida): reuse logic from ExportDatasetCommand once
        # it's implemented
        for dataset in database.tables:
            dataset_payload = dataset.export_to_dict(
                recursive=True,
                include_parent_ref=False,
                include_defaults=True,
                export_uuids=True,
            )
            dataset_payload["version"] = IMPORT_EXPORT_VERSION
            dataset_payload["database_uuid"] = str(database.uuid)
            yield (
                f"datasets/{sanitize(dataset.table_name)}.yaml",
                yaml.safe_dump(dataset_payload, sort_keys=False),
            )

    def run(self) -> Iterator[Tuple[str, str]]:
        self.validate()
        for model in self._models:
            yield from self.export_database(model)

    def validate(self) -> None:
        # A length mismatch means at least one requested id could not be
        # resolved (non-existent, or presumably filtered out by access
        # control — see the "no access" unit test).
        self._models = DatabaseDAO.find_by_ids(self.database_ids)
        if len(self._models) != len(self.database_ids):
            raise DatabaseNotFoundError()

View File

@ -109,6 +109,7 @@ extra_description = markdown(
"whether or not the Explore button in SQL Lab results is shown.",
True,
)
get_export_ids_schema = {"type": "array", "items": {"type": "integer"}}
sqlalchemy_uri_description = markdown(
"Refer to the "
"[SqlAlchemy docs]"

View File

@ -52,6 +52,23 @@ def json_to_dict(json_str: str) -> Dict[Any, Any]:
return {}
def convert_uuids(obj: Any) -> Any:
    """
    Recursively replace UUID instances with their string form.

    Lists and dicts are walked; every other value is returned unchanged.
    This makes the structure safe for yaml.safe_dump.
    """
    if isinstance(obj, uuid.UUID):
        return str(obj)
    if isinstance(obj, dict):
        return {key: convert_uuids(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [convert_uuids(item) for item in obj]
    return obj
# TODO (betodealmeida): rename to ImportExportMixin
class ImportMixin:
uuid = sa.Column(
UUIDType(binary=True), primary_key=False, unique=True, default=uuid.uuid4
@ -247,8 +264,15 @@ class ImportMixin:
recursive: bool = True,
include_parent_ref: bool = False,
include_defaults: bool = False,
export_uuids: bool = False,
) -> Dict[Any, Any]:
"""Export obj to dictionary"""
export_fields = set(self.export_fields)
if export_uuids:
export_fields.add("uuid")
if "id" in export_fields:
export_fields.remove("id")
cls = self.__class__
parent_excludes = set()
if recursive and not include_parent_ref:
@ -259,7 +283,7 @@ class ImportMixin:
c.name: getattr(self, c.name)
for c in cls.__table__.columns # type: ignore
if (
c.name in self.export_fields
c.name in export_fields
and c.name not in parent_excludes
and (
include_defaults
@ -270,6 +294,13 @@ class ImportMixin:
)
)
}
# sort according to export_fields using DSU (decorate, sort, undecorate)
order = {field: i for i, field in enumerate(self.export_fields)}
decorated_keys = [(order.get(k, len(order)), k) for k in dict_rep]
decorated_keys.sort()
dict_rep = {k: dict_rep[k] for _, k in decorated_keys}
if recursive:
for cld in self.export_children:
# sorting to make lists of children stable
@ -285,7 +316,7 @@ class ImportMixin:
key=lambda k: sorted(str(k.items())),
)
return dict_rep
return convert_uuids(dict_rep)
def override(self, obj: Any) -> None:
"""Overrides the plain fields of the dashboard."""

View File

@ -15,6 +15,8 @@
# specific language governing permissions and limitations
# under the License.
import logging
import re
import unicodedata
from typing import Any, Dict, List, Optional
from sqlalchemy.orm import Session
@ -22,6 +24,7 @@ from sqlalchemy.orm import Session
from superset.connectors.druid.models import DruidCluster
from superset.models.core import Database
IMPORT_EXPORT_VERSION = "1.0.0"
DATABASES_KEY = "databases"
DRUID_CLUSTERS_KEY = "druid_clusters"
logger = logging.getLogger(__name__)
@ -95,3 +98,18 @@ def import_from_dict(
session.commit()
else:
logger.info("Supplied object is not a dictionary.")
def strip_accents(text: str) -> str:
    """Return ``text`` with accents/diacritics removed (ASCII only)."""
    # NFD splits each accented character into base char + combining mark;
    # encoding to ASCII with "ignore" then drops the combining marks.
    # (The bytes are guaranteed ASCII at this point, so decode as such;
    # the previous utf-8 decode and redundant str() wrapper are dropped.)
    return unicodedata.normalize("NFD", text).encode("ascii", "ignore").decode("ascii")


def sanitize(name: str) -> str:
    """Sanitize a name into a safe file/directory name."""
    name = name.lower().replace(" ", "_")
    # keep word characters only (alphanumerics and underscore; accented
    # letters still match \w and are transliterated below)
    name = re.sub(r"[^\w]", "", name)
    name = strip_accents(name)
    return name

View File

@ -962,7 +962,6 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin):
}
]
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
print(rv.data)
self.assertEqual(rv.status_code, 200)
response_payload = json.loads(rv.data.decode("utf-8"))
result = response_payload["result"][0]

View File

@ -24,6 +24,7 @@ import prison
from sqlalchemy.sql import func
import tests.test_app
from freezegun import freeze_time
from sqlalchemy import and_
from superset import db, security_manager
from superset.models.dashboard import Dashboard
@ -955,12 +956,14 @@ class TestDashboardApi(SupersetTestCase, ApiOwnersTestCaseMixin):
self.login(username="admin")
argument = [1, 2]
uri = f"api/v1/dashboard/export/?q={prison.dumps(argument)}"
rv = self.get_assert_metric(uri, "export")
self.assertEqual(rv.status_code, 200)
self.assertEqual(
rv.headers["Content-Disposition"],
generate_download_headers("json")["Content-Disposition"],
)
# freeze time to ensure filename is deterministic
with freeze_time("2020-01-01T00:00:00Z"):
rv = self.get_assert_metric(uri, "export")
headers = generate_download_headers("json")["Content-Disposition"]
assert rv.status_code == 200
assert rv.headers["Content-Disposition"] == headers
def test_export_not_found(self):
"""

View File

@ -18,6 +18,8 @@
"""Unit tests for Superset"""
import datetime
import json
from io import BytesIO
from zipfile import is_zipfile
import pandas as pd
import prison
@ -801,3 +803,45 @@ class TestDatabaseApi(SupersetTestCase):
uri = f"api/v1/database/{database.id}/related_objects/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 404)
def test_export_database(self):
    """
    Database API: Test export database
    """
    self.login(username="admin")
    database = get_example_database()
    argument = [database.id]
    uri = f"api/v1/database/export/?q={prison.dumps(argument)}"
    rv = self.client.get(uri)
    assert rv.status_code == 200
    # the response body must be a well-formed ZIP archive
    buf = BytesIO(rv.data)
    assert is_zipfile(buf)
def test_export_database_not_allowed(self):
    """
    Database API: Test export database not allowed
    """
    # gamma user lacks the export permission, so the endpoint returns 401
    self.login(username="gamma")
    database = get_example_database()
    argument = [database.id]
    uri = f"api/v1/database/export/?q={prison.dumps(argument)}"
    rv = self.client.get(uri)
    assert rv.status_code == 401
def test_export_database_non_existing(self):
    """
    Database API: Test export database that does not exist
    """
    max_id = db.session.query(func.max(Database.id)).scalar()
    # id does not exist and we get 404
    invalid_id = max_id + 1
    self.login(username="admin")
    argument = [invalid_id]
    uri = f"api/v1/database/export/?q={prison.dumps(argument)}"
    rv = self.client.get(uri)
    assert rv.status_code == 404

View File

@ -0,0 +1,266 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from unittest.mock import patch
import yaml
from superset import db, security_manager
from superset.databases.commands.exceptions import DatabaseNotFoundError
from superset.databases.commands.export import ExportDatabasesCommand
from superset.models.core import Database
from superset.utils.core import backend, get_example_database
from tests.base_tests import SupersetTestCase
class TestExportDatabasesCommand(SupersetTestCase):
@patch("superset.security.manager.g")
def test_export_database_command(self, mock_g):
mock_g.user = security_manager.find_user("admin")
example_db = get_example_database()
command = ExportDatabasesCommand(database_ids=[example_db.id])
contents = dict(command.run())
# TODO: this list shouldn't depend on the order in which unit tests are run
# or on the backend; for now use a stable subset
core_datasets = {
"databases/examples.yaml",
"datasets/energy_usage.yaml",
"datasets/wb_health_population.yaml",
"datasets/birth_names.yaml",
}
expected_extra = {
"engine_params": {},
"metadata_cache_timeout": {},
"metadata_params": {},
"schemas_allowed_for_csv_upload": [],
}
if backend() == "presto":
expected_extra = {"engine_params": {"connect_args": {"poll_interval": 0.1}}}
assert core_datasets.issubset(set(contents.keys()))
metadata = yaml.safe_load(contents["databases/examples.yaml"])
assert metadata == (
{
"allow_csv_upload": True,
"allow_ctas": True,
"allow_cvas": True,
"allow_run_async": False,
"cache_timeout": None,
"database_name": "examples",
"expose_in_sqllab": True,
"extra": expected_extra,
"sqlalchemy_uri": example_db.sqlalchemy_uri,
"uuid": str(example_db.uuid),
"version": "1.0.0",
}
)
metadata = yaml.safe_load(contents["datasets/birth_names.yaml"])
metadata.pop("uuid")
assert metadata == {
"table_name": "birth_names",
"main_dttm_col": None,
"description": "Adding a DESCRip",
"default_endpoint": "",
"offset": 66,
"cache_timeout": 55,
"schema": "",
"sql": "",
"params": None,
"template_params": None,
"filter_select_enabled": True,
"fetch_values_predicate": None,
"metrics": [
{
"metric_name": "ratio",
"verbose_name": "Ratio Boys/Girls",
"metric_type": None,
"expression": "sum(sum_boys) / sum(sum_girls)",
"description": "This represents the ratio of boys/girls",
"d3format": ".2%",
"extra": None,
"warning_text": "no warning",
},
{
"metric_name": "sum__num",
"verbose_name": "Babies",
"metric_type": None,
"expression": "SUM(num)",
"description": "",
"d3format": "",
"extra": None,
"warning_text": "",
},
{
"metric_name": "count",
"verbose_name": "",
"metric_type": None,
"expression": "count(1)",
"description": None,
"d3format": None,
"extra": None,
"warning_text": None,
},
],
"columns": [
{
"column_name": "num_california",
"verbose_name": None,
"is_dttm": False,
"is_active": None,
"type": "NUMBER",
"groupby": False,
"filterable": False,
"expression": "CASE WHEN state = 'CA' THEN num ELSE 0 END",
"description": None,
"python_date_format": None,
},
{
"column_name": "ds",
"verbose_name": "",
"is_dttm": True,
"is_active": None,
"type": "DATETIME",
"groupby": True,
"filterable": True,
"expression": "",
"description": None,
"python_date_format": None,
},
{
"column_name": "sum_girls",
"verbose_name": None,
"is_dttm": False,
"is_active": None,
"type": "BIGINT(20)",
"groupby": False,
"filterable": False,
"expression": "",
"description": None,
"python_date_format": None,
},
{
"column_name": "gender",
"verbose_name": None,
"is_dttm": False,
"is_active": None,
"type": "VARCHAR(16)",
"groupby": True,
"filterable": True,
"expression": "",
"description": None,
"python_date_format": None,
},
{
"column_name": "state",
"verbose_name": None,
"is_dttm": None,
"is_active": None,
"type": "VARCHAR(10)",
"groupby": True,
"filterable": True,
"expression": None,
"description": None,
"python_date_format": None,
},
{
"column_name": "sum_boys",
"verbose_name": None,
"is_dttm": None,
"is_active": None,
"type": "BIGINT(20)",
"groupby": True,
"filterable": True,
"expression": None,
"description": None,
"python_date_format": None,
},
{
"column_name": "num",
"verbose_name": None,
"is_dttm": None,
"is_active": None,
"type": "BIGINT(20)",
"groupby": True,
"filterable": True,
"expression": None,
"description": None,
"python_date_format": None,
},
{
"column_name": "name",
"verbose_name": None,
"is_dttm": None,
"is_active": None,
"type": "VARCHAR(255)",
"groupby": True,
"filterable": True,
"expression": None,
"description": None,
"python_date_format": None,
},
],
"version": "1.0.0",
"database_uuid": str(example_db.uuid),
}
@patch("superset.security.manager.g")
def test_export_database_command_no_access(self, mock_g):
    """Test that users can't export databases they don't have access to"""
    mock_g.user = security_manager.find_user("gamma")
    example_db = get_example_database()
    command = ExportDatabasesCommand(database_ids=[example_db.id])
    contents = command.run()
    # run() is a generator, so validation only fires on first iteration
    with self.assertRaises(DatabaseNotFoundError):
        next(contents)
@patch("superset.security.manager.g")
def test_export_database_command_invalid_database(self, mock_g):
    """Test that an error is raised when exporting an invalid database"""
    mock_g.user = security_manager.find_user("admin")
    # -1 can never be a valid database id
    command = ExportDatabasesCommand(database_ids=[-1])
    contents = command.run()
    # run() is a generator, so validation only fires on first iteration
    with self.assertRaises(DatabaseNotFoundError):
        next(contents)
@patch("superset.security.manager.g")
def test_export_database_command_key_order(self, mock_g):
    """Test that the keys in the YAML have the same order as export_fields"""
    mock_g.user = security_manager.find_user("admin")
    example_db = get_example_database()
    command = ExportDatabasesCommand(database_ids=[example_db.id])
    contents = dict(command.run())
    metadata = yaml.safe_load(contents["databases/examples.yaml"])
    # export_fields order first, then uuid/version appended by the exporter
    assert list(metadata.keys()) == [
        "database_name",
        "sqlalchemy_uri",
        "cache_timeout",
        "expose_in_sqllab",
        "allow_run_async",
        "allow_ctas",
        "allow_cvas",
        "allow_csv_upload",
        "extra",
        "uuid",
        "version",
    ]

View File

@ -216,10 +216,9 @@ class TestDatasetApi(SupersetTestCase):
"admin_database",
"information_schema",
"public",
"superset",
]
expected_response = {
"count": 5,
"count": 4,
"result": [{"text": val, "value": val} for val in schema_values],
}
self.login(username="admin")
@ -243,14 +242,14 @@ class TestDatasetApi(SupersetTestCase):
query_parameter = {"page": 0, "page_size": 1}
pg_test_query_parameter(
query_parameter, {"count": 5, "result": [{"text": "", "value": ""}]},
query_parameter, {"count": 4, "result": [{"text": "", "value": ""}]},
)
query_parameter = {"page": 1, "page_size": 1}
pg_test_query_parameter(
query_parameter,
{
"count": 5,
"count": 4,
"result": [{"text": "admin_database", "value": "admin_database"}],
},
)

View File

@ -20,6 +20,7 @@ from unittest.mock import patch
import pytest
import tests.test_app
from superset import db
from superset.connectors.sqla.models import SqlaTable, TableColumn
from superset.db_engine_specs.druid import DruidEngineSpec
from superset.exceptions import QueryObjectValidationError
@ -86,49 +87,54 @@ class TestDatabaseModel(SupersetTestCase):
}
# Table with Jinja callable.
table = SqlaTable(
table1 = SqlaTable(
table_name="test_has_extra_cache_keys_table",
sql="SELECT '{{ current_username() }}' as user",
database=get_example_database(),
)
query_obj = dict(**base_query_obj, extras={})
extra_cache_keys = table.get_extra_cache_keys(query_obj)
self.assertTrue(table.has_extra_cache_key_calls(query_obj))
extra_cache_keys = table1.get_extra_cache_keys(query_obj)
self.assertTrue(table1.has_extra_cache_key_calls(query_obj))
assert extra_cache_keys == ["abc"]
# Table with Jinja callable disabled.
table = SqlaTable(
table2 = SqlaTable(
table_name="test_has_extra_cache_keys_disabled_table",
sql="SELECT '{{ current_username(False) }}' as user",
database=get_example_database(),
)
query_obj = dict(**base_query_obj, extras={})
extra_cache_keys = table.get_extra_cache_keys(query_obj)
self.assertTrue(table.has_extra_cache_key_calls(query_obj))
extra_cache_keys = table2.get_extra_cache_keys(query_obj)
self.assertTrue(table2.has_extra_cache_key_calls(query_obj))
self.assertListEqual(extra_cache_keys, [])
# Table with no Jinja callable.
query = "SELECT 'abc' as user"
table = SqlaTable(
table3 = SqlaTable(
table_name="test_has_no_extra_cache_keys_table",
sql=query,
database=get_example_database(),
)
query_obj = dict(**base_query_obj, extras={"where": "(user != 'abc')"})
extra_cache_keys = table.get_extra_cache_keys(query_obj)
self.assertFalse(table.has_extra_cache_key_calls(query_obj))
extra_cache_keys = table3.get_extra_cache_keys(query_obj)
self.assertFalse(table3.has_extra_cache_key_calls(query_obj))
self.assertListEqual(extra_cache_keys, [])
# With Jinja callable in SQL expression.
query_obj = dict(
**base_query_obj, extras={"where": "(user != '{{ current_username() }}')"}
)
extra_cache_keys = table.get_extra_cache_keys(query_obj)
self.assertTrue(table.has_extra_cache_key_calls(query_obj))
extra_cache_keys = table3.get_extra_cache_keys(query_obj)
self.assertTrue(table3.has_extra_cache_key_calls(query_obj))
assert extra_cache_keys == ["abc"]
# Cleanup
for table in [table1, table2, table3]:
db.session.delete(table)
db.session.commit()
def test_where_operators(self):
class FilterTestCase(NamedTuple):
operator: str

View File

@ -384,6 +384,10 @@ class TestSqlLab(SupersetTestCase):
view_menu = security_manager.find_view_menu(table.get_perm())
assert view_menu is not None
# Cleanup
db.session.delete(table)
db.session.commit()
def test_sqllab_viz_bad_payload(self):
self.login("admin")
payload = {