437 lines
16 KiB
Python
437 lines
16 KiB
Python
# Licensed to the Apache Software Foundation (ASF) under one
|
|
# or more contributor license agreements. See the NOTICE file
|
|
# distributed with this work for additional information
|
|
# regarding copyright ownership. The ASF licenses this file
|
|
# to you under the Apache License, Version 2.0 (the
|
|
# "License"); you may not use this file except in compliance
|
|
# with the License. You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing,
|
|
# software distributed under the License is distributed on an
|
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
# KIND, either express or implied. See the License for the
|
|
# specific language governing permissions and limitations
|
|
# under the License.
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
from typing import Any, TYPE_CHECKING
|
|
|
|
import simplejson
|
|
from flask import current_app, g, make_response, request, Response
|
|
from flask_appbuilder.api import expose, protect
|
|
from flask_babel import gettext as _
|
|
from marshmallow import ValidationError
|
|
|
|
from superset import is_feature_enabled, security_manager
|
|
from superset.charts.api import ChartRestApi
|
|
from superset.charts.commands.exceptions import (
|
|
ChartDataCacheLoadError,
|
|
ChartDataQueryFailedError,
|
|
)
|
|
from superset.charts.data.commands.create_async_job_command import (
|
|
CreateAsyncChartDataJobCommand,
|
|
)
|
|
from superset.charts.data.commands.get_data_command import ChartDataCommand
|
|
from superset.charts.data.query_context_cache_loader import QueryContextCacheLoader
|
|
from superset.charts.post_processing import apply_post_process
|
|
from superset.charts.schemas import ChartDataQueryContextSchema
|
|
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
|
|
from superset.connectors.base.models import BaseDatasource
|
|
from superset.dao.exceptions import DatasourceNotFound
|
|
from superset.exceptions import QueryObjectValidationError
|
|
from superset.extensions import event_logger
|
|
from superset.models.sql_lab import Query
|
|
from superset.utils.async_query_manager import AsyncQueryTokenException
|
|
from superset.utils.core import create_zip, get_user_id, json_int_dttm_ser
|
|
from superset.views.base import CsvResponse, generate_download_headers, XlsxResponse
|
|
from superset.views.base_api import statsd_metrics
|
|
|
|
if TYPE_CHECKING:
    # Needed only for the ``QueryContext`` return annotation; presumably guarded
    # so it is never imported at runtime (e.g. to avoid an import cycle) — TODO confirm.
    from superset.common.query_context import QueryContext

# Logger named after this module.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class ChartDataRestApi(ChartRestApi):
    """
    Endpoints that run a chart query context and return its results.

    * ``GET /<pk>/data/`` runs the query context saved with a chart.
    * ``POST /data`` runs a query context constructed by the client.
    * ``GET /data/<cache_key>`` runs a query context loaded from the cache.
    """

    include_route_methods = {"get_data", "data", "data_from_cache"}

    @expose("/<int:pk>/data/", methods=("GET",))
    @protect()
    @statsd_metrics
    @event_logger.log_this_with_context(
        action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.data",
        log_to_statsd=False,
    )
    def get_data(self, pk: int) -> Response:
        """
        Takes a chart ID and uses the query context stored when the chart was saved
        to return payload data response.

        The ``format``, ``type`` and ``force`` query string arguments override the
        corresponding values of the saved query context.
        ---
        get:
          description: >-
            Takes a chart ID and uses the query context stored when the chart was saved
            to return payload data response.
          parameters:
          - in: path
            schema:
              type: integer
            name: pk
            description: The chart ID
          - in: query
            name: format
            description: The format in which the data should be returned
            schema:
              type: string
          - in: query
            name: type
            description: The type in which the data should be returned
            schema:
              type: string
          - in: query
            name: force
            description: Should the queries be forced to load from the source
            schema:
              type: boolean
          responses:
            200:
              description: Query result
              content:
                application/json:
                  schema:
                    $ref: "#/components/schemas/ChartDataResponseSchema"
            202:
              description: Async job details
              content:
                application/json:
                  schema:
                    $ref: "#/components/schemas/ChartDataAsyncResponseSchema"
            400:
              $ref: '#/components/responses/400'
            401:
              $ref: '#/components/responses/401'
            403:
              $ref: '#/components/responses/403'
            404:
              $ref: '#/components/responses/404'
            500:
              $ref: '#/components/responses/500'
        """
        chart = self.datamodel.get(pk, self._base_filters)
        if not chart:
            return self.response_404()

        try:
            json_body = json.loads(chart.query_context)
        except (TypeError, json.JSONDecodeError):
            json_body = None

        # Besides a missing/unparseable query context, valid JSON that is not an
        # object (``null``, a list, a string, ...) cannot be overridden below.
        if not isinstance(json_body, dict):
            return self.response_400(
                message=_(
                    "Chart has no query context saved. Please save the chart again."
                )
            )

        # override saved query context
        json_body["result_format"] = request.args.get(
            "format", ChartDataResultFormat.JSON
        )
        json_body["result_type"] = request.args.get("type", ChartDataResultType.FULL)
        json_body["force"] = request.args.get("force")

        try:
            query_context = self._create_query_context_from_form(json_body)
            command = ChartDataCommand(query_context)
            command.validate()
        except DatasourceNotFound:
            return self.response_404()
        except QueryObjectValidationError as error:
            return self.response_400(message=error.message)
        except ValidationError as error:
            return self.response_400(
                message=_(
                    "Request is incorrect: %(error)s", error=error.normalized_messages()
                )
            )

        if self._should_run_async(query_context):
            return self._run_async(json_body, command)

        # The chart's saved params are forwarded as form data (used when
        # post-processing results); fall back to {} when missing or not JSON.
        try:
            form_data = json.loads(chart.params)
        except (TypeError, json.JSONDecodeError):
            form_data = {}

        return self._get_data_response(
            command=command, form_data=form_data, datasource=query_context.datasource
        )

    @expose("/data", methods=("POST",))
    @protect()
    @statsd_metrics
    @event_logger.log_this_with_context(
        action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.data",
        log_to_statsd=False,
    )
    def data(self) -> Response:
        """
        Takes a query context constructed in the client and returns payload
        data response for the given query.

        The query context is read from a JSON body or, for form submissions
        (e.g. CSV export), from the JSON-encoded ``form_data`` form field.
        ---
        post:
          description: >-
            Takes a query context constructed in the client and returns payload data
            response for the given query.
          requestBody:
            description: >-
              A query context consists of a datasource from which to fetch data
              and one or many query objects.
            required: true
            content:
              application/json:
                schema:
                  $ref: "#/components/schemas/ChartDataQueryContextSchema"
          responses:
            200:
              description: Query result
              content:
                application/json:
                  schema:
                    $ref: "#/components/schemas/ChartDataResponseSchema"
            202:
              description: Async job details
              content:
                application/json:
                  schema:
                    $ref: "#/components/schemas/ChartDataAsyncResponseSchema"
            400:
              $ref: '#/components/responses/400'
            401:
              $ref: '#/components/responses/401'
            403:
              $ref: '#/components/responses/403'
            404:
              $ref: '#/components/responses/404'
            500:
              $ref: '#/components/responses/500'
        """
        json_body = None
        if request.is_json:
            json_body = request.json
        elif request.form.get("form_data"):
            # CSV export submits regular form data
            try:
                json_body = json.loads(request.form["form_data"])
            except (TypeError, json.JSONDecodeError):
                # Unparseable form data is reported as "not JSON" just below.
                pass

        if json_body is None:
            return self.response_400(message=_("Request is not JSON"))

        try:
            query_context = self._create_query_context_from_form(json_body)
            command = ChartDataCommand(query_context)
            command.validate()
        except DatasourceNotFound:
            return self.response_404()
        except QueryObjectValidationError as error:
            return self.response_400(message=error.message)
        except ValidationError as error:
            return self.response_400(
                message=_(
                    "Request is incorrect: %(error)s", error=error.normalized_messages()
                )
            )

        if self._should_run_async(query_context):
            return self._run_async(json_body, command)

        form_data = json_body.get("form_data")
        return self._get_data_response(
            command, form_data=form_data, datasource=query_context.datasource
        )

    @expose("/data/<cache_key>", methods=("GET",))
    @protect()
    @statsd_metrics
    @event_logger.log_this_with_context(
        action=lambda self, *args, **kwargs: f"{self.__class__.__name__}"
        f".data_from_cache",
        log_to_statsd=False,
    )
    def data_from_cache(self, cache_key: str) -> Response:
        """
        Takes a query context cache key and returns payload
        data response for the given query.

        Results are only read from the cache (``force_cached=True``); a cache
        miss while loading results yields a 422.
        ---
        get:
          description: >-
            Takes a query context cache key and returns payload data
            response for the given query.
          parameters:
          - in: path
            schema:
              type: string
            name: cache_key
          responses:
            200:
              description: Query result
              content:
                application/json:
                  schema:
                    $ref: "#/components/schemas/ChartDataResponseSchema"
            400:
              $ref: '#/components/responses/400'
            401:
              $ref: '#/components/responses/401'
            403:
              $ref: '#/components/responses/403'
            404:
              $ref: '#/components/responses/404'
            422:
              $ref: '#/components/responses/422'
            500:
              $ref: '#/components/responses/500'
        """
        try:
            cached_data = self._load_query_context_form_from_cache(cache_key)
            # Set form_data in Flask Global as it is used as a fallback
            # for async queries with jinja context
            g.form_data = cached_data
            query_context = self._create_query_context_from_form(cached_data)
            command = ChartDataCommand(query_context)
            command.validate()
        except (ChartDataCacheLoadError, DatasourceNotFound):
            # Unknown cache key or a datasource that no longer exists.
            return self.response_404()
        except QueryObjectValidationError as error:
            return self.response_400(message=error.message)
        except ValidationError as error:
            return self.response_400(
                message=_(
                    "Request is incorrect: %(error)s", error=error.normalized_messages()
                )
            )

        return self._get_data_response(command, force_cached=True)

    @staticmethod
    def _should_run_async(query_context: QueryContext) -> bool:
        """
        Return True when the query must be executed as an async job.

        That is the case when the ``GLOBAL_ASYNC_QUERIES`` feature flag is on and
        full JSON results are requested.
        """
        # TODO: support CSV, SQL query and other non-JSON types
        return (
            is_feature_enabled("GLOBAL_ASYNC_QUERIES")
            and query_context.result_format == ChartDataResultFormat.JSON
            and query_context.result_type == ChartDataResultType.FULL
        )

    def _run_async(
        self, form_data: dict[str, Any], command: ChartDataCommand
    ) -> Response:
        """
        Execute command as an async query.

        Results already in the cache are returned directly. Otherwise a background
        job is created and its details are returned with HTTP 202; a 401 is
        returned when the async query token fails validation.
        """
        # First, look for the chart query results in the cache.
        try:
            result = command.run(force_cached=True)
        except ChartDataCacheLoadError:
            result = None

        if result is not None:
            return self._send_chart_response(result)

        # Otherwise, kick off a background job to run the chart query.
        # Clients will either poll or be notified of query completion,
        # at which point they will call the /data/<cache_key> endpoint
        # to retrieve the results.
        async_command = CreateAsyncChartDataJobCommand()
        try:
            async_command.validate(request)
        except AsyncQueryTokenException:
            return self.response_401()

        job_metadata = async_command.run(form_data, get_user_id())
        return self.response(202, **job_metadata)

    def _send_chart_response(
        self,
        result: dict[Any, Any],
        form_data: dict[str, Any] | None = None,
        datasource: BaseDatasource | Query | None = None,
    ) -> Response:
        """
        Serialize a command result into an HTTP response.

        :param result: the command output; ``result["query_context"]`` supplies the
            requested result type/format and ``result["queries"]`` the data
        :param form_data: chart form data, used only for post-processing
        :param datasource: chart datasource, used only for post-processing
        :returns: a CSV/XLSX download for a single table-like query, a zip archive
            for several, a JSON document for JSON, 403 when the user lacks the
            ``can_csv`` permission, 400 for no queries or an unsupported format
        """
        result_type = result["query_context"].result_type
        result_format = result["query_context"].result_format

        # Post-process the data so it matches the data presented in the chart.
        # This is needed for sending reports based on text charts that do the
        # post-processing of data, eg, the pivot table.
        if result_type == ChartDataResultType.POST_PROCESSED:
            result = apply_post_process(result, form_data, datasource)

        if result_format in ChartDataResultFormat.table_like():
            # Verify user has permission to export file
            if not security_manager.can_access("can_csv", "Superset"):
                return self.response_403()

            if not result["queries"]:
                return self.response_400(message=_("Empty query result"))

            is_csv_format = result_format == ChartDataResultFormat.CSV

            if len(result["queries"]) == 1:
                # return single query results
                data = result["queries"][0]["data"]
                if is_csv_format:
                    return CsvResponse(data, headers=generate_download_headers("csv"))

                return XlsxResponse(data, headers=generate_download_headers("xlsx"))

            # return multi-query results bundled as a zip file

            # Use the enum *value* (e.g. "csv") as the file extension: depending on
            # the Python version, formatting a (str, Enum) member in an f-string
            # can produce "ClassName.MEMBER" instead of the value.
            extension = ChartDataResultFormat(result_format).value

            def _process_data(query_data: Any) -> Any:
                # CSV text is encoded to bytes with the CSV_EXPORT encoding
                # (utf-8 by default); other formats are passed through unchanged.
                if is_csv_format:
                    encoding = current_app.config["CSV_EXPORT"].get("encoding", "utf-8")
                    return query_data.encode(encoding)
                return query_data

            files = {
                f"query_{idx + 1}.{extension}": _process_data(query["data"])
                for idx, query in enumerate(result["queries"])
            }
            return Response(
                create_zip(files),
                headers=generate_download_headers("zip"),
                mimetype="application/zip",
            )

        if result_format == ChartDataResultFormat.JSON:
            response_data = simplejson.dumps(
                {"result": result["queries"]},
                default=json_int_dttm_ser,
                ignore_nan=True,
            )
            resp = make_response(response_data, 200)
            resp.headers["Content-Type"] = "application/json; charset=utf-8"
            return resp

        return self.response_400(message=f"Unsupported result_format: {result_format}")

    def _get_data_response(
        self,
        command: ChartDataCommand,
        force_cached: bool = False,
        form_data: dict[str, Any] | None = None,
        datasource: BaseDatasource | Query | None = None,
    ) -> Response:
        """
        Run ``command`` and convert its result into a response.

        :param force_cached: forwarded to ``command.run``; callers pass True when
            results must come from the cache
        :returns: the serialized result, 422 when ``ChartDataCacheLoadError`` is
            raised, or 400 when the query fails (``ChartDataQueryFailedError``)
        """
        try:
            result = command.run(force_cached=force_cached)
        except ChartDataCacheLoadError as exc:
            return self.response_422(message=exc.message)
        except ChartDataQueryFailedError as exc:
            return self.response_400(message=exc.message)

        return self._send_chart_response(result, form_data, datasource)

    # pylint: disable=invalid-name, no-self-use
    def _load_query_context_form_from_cache(self, cache_key: str) -> dict[str, Any]:
        """
        Load the query context payload stored under ``cache_key``.

        Presumably raises ``ChartDataCacheLoadError`` when the key is unknown
        (``data_from_cache`` maps that error to a 404) — TODO confirm.
        """
        return QueryContextCacheLoader.load(cache_key)

    # pylint: disable=no-self-use
    def _create_query_context_from_form(
        self, form_data: dict[str, Any]
    ) -> QueryContext:
        """
        Deserialize a query context payload with ``ChartDataQueryContextSchema``.

        :raises ValidationError: when the payload does not match the schema; a
            ``KeyError`` raised during loading is converted to ``ValidationError``
        """
        try:
            return ChartDataQueryContextSchema().load(form_data)
        except KeyError as ex:
            raise ValidationError("Request is incorrect") from ex
|