feat(chart-data-api): download multiple csvs as zip (#18618)
* feat(chart-data-api): download multiple csvs as zip * break out util * check for empty request
This commit is contained in:
parent
9c08bc0ffc
commit
125be78ee6
|
|
@ -21,7 +21,7 @@ import logging
|
|||
from typing import Any, Dict, Optional, TYPE_CHECKING
|
||||
|
||||
import simplejson
|
||||
from flask import g, make_response, request
|
||||
from flask import current_app, g, make_response, request, Response
|
||||
from flask_appbuilder.api import expose, protect
|
||||
from flask_babel import gettext as _
|
||||
from marshmallow import ValidationError
|
||||
|
|
@ -44,13 +44,11 @@ from superset.connectors.base.models import BaseDatasource
|
|||
from superset.exceptions import QueryObjectValidationError
|
||||
from superset.extensions import event_logger
|
||||
from superset.utils.async_query_manager import AsyncQueryTokenException
|
||||
from superset.utils.core import json_int_dttm_ser
|
||||
from superset.utils.core import create_zip, json_int_dttm_ser
|
||||
from superset.views.base import CsvResponse, generate_download_headers
|
||||
from superset.views.base_api import statsd_metrics
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from flask import Response
|
||||
|
||||
from superset.common.query_context import QueryContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -350,9 +348,25 @@ class ChartDataRestApi(ChartRestApi):
|
|||
if not security_manager.can_access("can_csv", "Superset"):
|
||||
return self.response_403()
|
||||
|
||||
# return the first result
|
||||
data = result["queries"][0]["data"]
|
||||
return CsvResponse(data, headers=generate_download_headers("csv"))
|
||||
if not result["queries"]:
|
||||
return self.response_400(_("Empty query result"))
|
||||
|
||||
if len(result["queries"]) == 1:
|
||||
# return single query results csv format
|
||||
data = result["queries"][0]["data"]
|
||||
return CsvResponse(data, headers=generate_download_headers("csv"))
|
||||
|
||||
# return multi-query csv results bundled as a zip file
|
||||
encoding = current_app.config["CSV_EXPORT"].get("encoding", "utf-8")
|
||||
files = {
|
||||
f"query_{idx + 1}.csv": result["data"].encode(encoding)
|
||||
for idx, result in enumerate(result["queries"])
|
||||
}
|
||||
return Response(
|
||||
create_zip(files),
|
||||
headers=generate_download_headers("zip"),
|
||||
mimetype="application/zip",
|
||||
)
|
||||
|
||||
if result_format == ChartDataResultFormat.JSON:
|
||||
response_data = simplejson.dumps(
|
||||
|
|
|
|||
|
|
@ -40,6 +40,7 @@ from email.mime.multipart import MIMEMultipart
|
|||
from email.mime.text import MIMEText
|
||||
from email.utils import formatdate
|
||||
from enum import Enum, IntEnum
|
||||
from io import BytesIO
|
||||
from timeit import default_timer
|
||||
from types import TracebackType
|
||||
from typing import (
|
||||
|
|
@ -61,6 +62,7 @@ from typing import (
|
|||
Union,
|
||||
)
|
||||
from urllib.parse import unquote_plus
|
||||
from zipfile import ZipFile
|
||||
|
||||
import bleach
|
||||
import markdown as md
|
||||
|
|
@ -1788,3 +1790,13 @@ def apply_max_row_limit(limit: int, max_limit: Optional[int] = None,) -> int:
|
|||
if limit != 0:
|
||||
return min(max_limit, limit)
|
||||
return max_limit
|
||||
|
||||
|
||||
def create_zip(files: Dict[str, Any]) -> BytesIO:
|
||||
buf = BytesIO()
|
||||
with ZipFile(buf, "w") as bundle:
|
||||
for filename, contents in files.items():
|
||||
with bundle.open(filename, "w") as fp:
|
||||
fp.write(contents)
|
||||
buf.seek(0)
|
||||
return buf
|
||||
|
|
|
|||
|
|
@ -20,8 +20,11 @@ import json
|
|||
import unittest
|
||||
import copy
|
||||
from datetime import datetime
|
||||
from io import BytesIO
|
||||
from typing import Optional
|
||||
from unittest import mock
|
||||
from zipfile import ZipFile
|
||||
|
||||
from flask import Response
|
||||
from tests.integration_tests.conftest import with_feature_flags
|
||||
from superset.models.sql_lab import Query
|
||||
|
|
@ -235,6 +238,16 @@ class TestPostChartDataApi(BaseTestChartDataApi):
|
|||
rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data")
|
||||
assert rv.status_code == 200
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_empty_request_with_csv_result_format(self):
|
||||
"""
|
||||
Chart data API: Test empty chart data with CSV result format
|
||||
"""
|
||||
self.query_context_payload["result_format"] = "csv"
|
||||
self.query_context_payload["queries"] = []
|
||||
rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data")
|
||||
assert rv.status_code == 400
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_with_csv_result_format(self):
|
||||
"""
|
||||
|
|
@ -243,6 +256,22 @@ class TestPostChartDataApi(BaseTestChartDataApi):
|
|||
self.query_context_payload["result_format"] = "csv"
|
||||
rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data")
|
||||
assert rv.status_code == 200
|
||||
assert rv.mimetype == "text/csv"
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_with_multi_query_csv_result_format(self):
|
||||
"""
|
||||
Chart data API: Test chart data with multi-query CSV result format
|
||||
"""
|
||||
self.query_context_payload["result_format"] = "csv"
|
||||
self.query_context_payload["queries"].append(
|
||||
self.query_context_payload["queries"][0]
|
||||
)
|
||||
rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data")
|
||||
assert rv.status_code == 200
|
||||
assert rv.mimetype == "application/zip"
|
||||
zipfile = ZipFile(BytesIO(rv.data), "r")
|
||||
assert zipfile.namelist() == ["query_1.csv", "query_2.csv"]
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_with_csv_result_format_when_actor_not_permitted_for_csv__403(self):
|
||||
|
|
@ -766,6 +795,7 @@ class TestGetChartDataApi(BaseTestChartDataApi):
|
|||
}
|
||||
)
|
||||
rv = self.get_assert_metric(f"api/v1/chart/{chart.id}/data/", "get_data")
|
||||
assert rv.mimetype == "application/json"
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
assert data["result"][0]["status"] == "success"
|
||||
assert data["result"][0]["rowcount"] == 2
|
||||
|
|
|
|||
Loading…
Reference in New Issue