feat(chart-data-api): download multiple csvs as zip (#18618)

* feat(chart-data-api): download multiple csvs as zip

* break out util

* check for empty request
This commit is contained in:
Ville Brofeldt 2022-02-08 21:25:06 +02:00 committed by GitHub
parent 9c08bc0ffc
commit 125be78ee6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 63 additions and 7 deletions

View File

@ -21,7 +21,7 @@ import logging
from typing import Any, Dict, Optional, TYPE_CHECKING
import simplejson
from flask import g, make_response, request
from flask import current_app, g, make_response, request, Response
from flask_appbuilder.api import expose, protect
from flask_babel import gettext as _
from marshmallow import ValidationError
@ -44,13 +44,11 @@ from superset.connectors.base.models import BaseDatasource
from superset.exceptions import QueryObjectValidationError
from superset.extensions import event_logger
from superset.utils.async_query_manager import AsyncQueryTokenException
from superset.utils.core import json_int_dttm_ser
from superset.utils.core import create_zip, json_int_dttm_ser
from superset.views.base import CsvResponse, generate_download_headers
from superset.views.base_api import statsd_metrics
if TYPE_CHECKING:
from flask import Response
from superset.common.query_context import QueryContext
logger = logging.getLogger(__name__)
@ -350,9 +348,25 @@ class ChartDataRestApi(ChartRestApi):
if not security_manager.can_access("can_csv", "Superset"):
return self.response_403()
# return the first result
data = result["queries"][0]["data"]
return CsvResponse(data, headers=generate_download_headers("csv"))
if not result["queries"]:
return self.response_400(_("Empty query result"))
if len(result["queries"]) == 1:
# return single query results csv format
data = result["queries"][0]["data"]
return CsvResponse(data, headers=generate_download_headers("csv"))
# return multi-query csv results bundled as a zip file
encoding = current_app.config["CSV_EXPORT"].get("encoding", "utf-8")
files = {
f"query_{idx + 1}.csv": result["data"].encode(encoding)
for idx, result in enumerate(result["queries"])
}
return Response(
create_zip(files),
headers=generate_download_headers("zip"),
mimetype="application/zip",
)
if result_format == ChartDataResultFormat.JSON:
response_data = simplejson.dumps(

View File

@ -40,6 +40,7 @@ from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from email.utils import formatdate
from enum import Enum, IntEnum
from io import BytesIO
from timeit import default_timer
from types import TracebackType
from typing import (
@ -61,6 +62,7 @@ from typing import (
Union,
)
from urllib.parse import unquote_plus
from zipfile import ZipFile
import bleach
import markdown as md
@ -1788,3 +1790,13 @@ def apply_max_row_limit(limit: int, max_limit: Optional[int] = None,) -> int:
if limit != 0:
return min(max_limit, limit)
return max_limit
def create_zip(files: Dict[str, Any]) -> BytesIO:
buf = BytesIO()
with ZipFile(buf, "w") as bundle:
for filename, contents in files.items():
with bundle.open(filename, "w") as fp:
fp.write(contents)
buf.seek(0)
return buf

View File

@ -20,8 +20,11 @@ import json
import unittest
import copy
from datetime import datetime
from io import BytesIO
from typing import Optional
from unittest import mock
from zipfile import ZipFile
from flask import Response
from tests.integration_tests.conftest import with_feature_flags
from superset.models.sql_lab import Query
@ -235,6 +238,16 @@ class TestPostChartDataApi(BaseTestChartDataApi):
rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data")
assert rv.status_code == 200
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_empty_request_with_csv_result_format(self):
"""
Chart data API: Test empty chart data with CSV result format
"""
self.query_context_payload["result_format"] = "csv"
self.query_context_payload["queries"] = []
rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data")
assert rv.status_code == 400
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_with_csv_result_format(self):
"""
@ -243,6 +256,22 @@ class TestPostChartDataApi(BaseTestChartDataApi):
self.query_context_payload["result_format"] = "csv"
rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data")
assert rv.status_code == 200
assert rv.mimetype == "text/csv"
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_with_multi_query_csv_result_format(self):
"""
Chart data API: Test chart data with multi-query CSV result format
"""
self.query_context_payload["result_format"] = "csv"
self.query_context_payload["queries"].append(
self.query_context_payload["queries"][0]
)
rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data")
assert rv.status_code == 200
assert rv.mimetype == "application/zip"
zipfile = ZipFile(BytesIO(rv.data), "r")
assert zipfile.namelist() == ["query_1.csv", "query_2.csv"]
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_with_csv_result_format_when_actor_not_permitted_for_csv__403(self):
@ -766,6 +795,7 @@ class TestGetChartDataApi(BaseTestChartDataApi):
}
)
rv = self.get_assert_metric(f"api/v1/chart/{chart.id}/data/", "get_data")
assert rv.mimetype == "application/json"
data = json.loads(rv.data.decode("utf-8"))
assert data["result"][0]["status"] == "success"
assert data["result"][0]["rowcount"] == 2