feat: Add new result formats and types to chart data API (#9841)
* feat: Add new result formats and types to chart data API * lint * Linting * Add language to query payload * Fix tests * simplify tests
This commit is contained in:
parent
368c85de1b
commit
a43a1d6303
|
|
@ -27,7 +27,7 @@ import ExploreChartPanel from './ExploreChartPanel';
|
|||
import ControlPanelsContainer from './ControlPanelsContainer';
|
||||
import SaveModal from './SaveModal';
|
||||
import QueryAndSaveBtns from './QueryAndSaveBtns';
|
||||
import { getExploreUrl, getExploreLongUrl } from '../exploreUtils';
|
||||
import { getExploreLongUrl } from '../exploreUtils';
|
||||
import { areObjectsEqual } from '../../reduxUtils';
|
||||
import { getFormDataFromControls } from '../controlUtils';
|
||||
import { chartPropShape } from '../../dashboard/util/propShapes';
|
||||
|
|
|
|||
|
|
@ -711,6 +711,14 @@ class ChartDataQueryContextSchema(Schema):
|
|||
description="Should the queries be forced to load from the source. "
|
||||
"Default: `false`",
|
||||
)
|
||||
result_type = fields.String(
|
||||
description="Type of results to return",
|
||||
validate=validate.OneOf(choices=("query", "results", "samples")),
|
||||
)
|
||||
result_format = fields.String(
|
||||
description="Format of result payload",
|
||||
validate=validate.OneOf(choices=("json", "csv")),
|
||||
)
|
||||
|
||||
# pylint: disable=no-self-use
|
||||
@post_load
|
||||
|
|
|
|||
|
|
@ -14,10 +14,11 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import copy
|
||||
import logging
|
||||
import pickle as pkl
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, ClassVar, Dict, List, Optional
|
||||
from typing import Any, ClassVar, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
|
@ -49,15 +50,19 @@ class QueryContext:
|
|||
queries: List[QueryObject]
|
||||
force: bool
|
||||
custom_cache_timeout: Optional[int]
|
||||
response_type: utils.ChartDataResponseType
|
||||
response_format: utils.ChartDataResponseFormat
|
||||
|
||||
# TODO: Type datasource and query_object dictionary with TypedDict when it becomes
|
||||
# a vanilla python type https://github.com/python/mypy/issues/5288
|
||||
def __init__(
|
||||
def __init__( # pylint: disable=too-many-arguments
|
||||
self,
|
||||
datasource: Dict[str, Any],
|
||||
queries: List[Dict[str, Any]],
|
||||
force: bool = False,
|
||||
custom_cache_timeout: Optional[int] = None,
|
||||
response_format: Optional[utils.ChartDataResponseFormat] = None,
|
||||
response_type: Optional[utils.ChartDataResponseType] = None,
|
||||
) -> None:
|
||||
self.datasource = ConnectorRegistry.get_datasource(
|
||||
str(datasource["type"]), int(datasource["id"]), db.session
|
||||
|
|
@ -65,6 +70,8 @@ class QueryContext:
|
|||
self.queries = [QueryObject(**query_obj) for query_obj in queries]
|
||||
self.force = force
|
||||
self.custom_cache_timeout = custom_cache_timeout
|
||||
self.response_format = response_format or utils.ChartDataResponseFormat.JSON
|
||||
self.response_type = response_type or utils.ChartDataResponseType.RESULTS
|
||||
|
||||
def get_query_result(self, query_object: QueryObject) -> Dict[str, Any]:
|
||||
"""Returns a pandas dataframe based on the query object"""
|
||||
|
|
@ -124,12 +131,32 @@ class QueryContext:
|
|||
if dtype.type == np.object_ and col in query_object.metrics:
|
||||
df[col] = pd.to_numeric(df[col], errors="coerce")
|
||||
|
||||
def get_data(self, df: pd.DataFrame) -> Union[str, List[Dict[str, Any]]]:
    """
    Serialize a dataframe according to the configured response format.

    :param df: dataframe holding the query result
    :returns: CSV text when ``self.response_format`` is CSV, otherwise a
        list containing one dict per row
    """
    if self.response_format == utils.ChartDataResponseFormat.CSV:
        # Only write the index when it is something other than the default
        # positional RangeIndex.
        include_index = not isinstance(df.index, pd.RangeIndex)
        result = df.to_csv(index=include_index, **config["CSV_EXPORT"])
        # to_csv may return None (it does so when writing to a target);
        # fall back to an empty string so the return value is always a str.
        return result or ""

    return df.to_dict(orient="records")
|
||||
|
||||
def get_single_payload(self, query_obj: QueryObject) -> Dict[str, Any]:
|
||||
"""Returns a payload of metadata and data"""
|
||||
if self.response_type == utils.ChartDataResponseType.QUERY:
|
||||
return {
|
||||
"query": self.datasource.get_query_str(query_obj.to_dict()),
|
||||
"language": self.datasource.query_language,
|
||||
}
|
||||
if self.response_type == utils.ChartDataResponseType.SAMPLES:
|
||||
row_limit = query_obj.row_limit or 1000
|
||||
query_obj = copy.copy(query_obj)
|
||||
query_obj.groupby = []
|
||||
query_obj.metrics = []
|
||||
query_obj.post_processing = []
|
||||
query_obj.row_limit = row_limit
|
||||
query_obj.columns = [o.column_name for o in self.datasource.columns]
|
||||
|
||||
payload = self.get_df_payload(query_obj)
|
||||
df = payload["df"]
|
||||
status = payload["status"]
|
||||
|
|
@ -142,7 +169,7 @@ class QueryContext:
|
|||
return payload
|
||||
|
||||
def get_payload(self) -> List[Dict[str, Any]]:
    """
    Get all the payloads from the QueryObjects

    :returns: one payload per entry in ``self.queries``, in the same order,
        each produced by ``self.get_single_payload``
    """
    return [self.get_single_payload(query_object) for query_object in self.queries]
|
||||
|
||||
@property
|
||||
|
|
|
|||
|
|
@ -1367,3 +1367,22 @@ class FilterOperator(str, Enum):
|
|||
IN = "IN"
|
||||
NOT_IN = "NOT IN"
|
||||
REGEX = "REGEX"
|
||||
|
||||
|
||||
class ChartDataResponseType(str, Enum):
    """
    Chart data response type

    Members also subclass ``str``, so each one compares equal to its raw
    string value.
    """

    QUERY = "query"
    RESULTS = "results"
    SAMPLES = "samples"
|
||||
|
||||
|
||||
class ChartDataResponseFormat(str, Enum):
    """
    Chart data response format

    Members also subclass ``str``, so each one compares equal to its raw
    string value.
    """

    CSV = "csv"
    JSON = "json"
|
||||
|
|
|
|||
|
|
@ -636,10 +636,8 @@ class Superset(BaseSupersetView):
|
|||
def get_samples(self, viz_obj):
|
||||
return self.json_response({"data": viz_obj.get_samples()})
|
||||
|
||||
def generate_json(
|
||||
self, viz_obj, csv=False, query=False, results=False, samples=False
|
||||
):
|
||||
if csv:
|
||||
def generate_json(self, viz_obj, response_type: Optional[str] = None) -> Response:
|
||||
if response_type == utils.ChartDataResponseFormat.CSV:
|
||||
return CsvResponse(
|
||||
viz_obj.get_csv(),
|
||||
status=200,
|
||||
|
|
@ -647,13 +645,13 @@ class Superset(BaseSupersetView):
|
|||
mimetype="application/csv",
|
||||
)
|
||||
|
||||
if query:
|
||||
if response_type == utils.ChartDataResponseType.QUERY:
|
||||
return self.get_query_string_response(viz_obj)
|
||||
|
||||
if results:
|
||||
if response_type == utils.ChartDataResponseType.RESULTS:
|
||||
return self.get_raw_results(viz_obj)
|
||||
|
||||
if samples:
|
||||
if response_type == utils.ChartDataResponseType.SAMPLES:
|
||||
return self.get_samples(viz_obj)
|
||||
|
||||
payload = viz_obj.get_payload()
|
||||
|
|
@ -715,11 +713,14 @@ class Superset(BaseSupersetView):
|
|||
payloads based on the request args in the first block
|
||||
|
||||
TODO: break into one endpoint for each return shape"""
|
||||
csv = request.args.get("csv") == "true"
|
||||
query = request.args.get("query") == "true"
|
||||
results = request.args.get("results") == "true"
|
||||
samples = request.args.get("samples") == "true"
|
||||
force = request.args.get("force") == "true"
|
||||
response_type = utils.ChartDataResponseFormat.JSON.value
|
||||
responses = [resp_format for resp_format in utils.ChartDataResponseFormat]
|
||||
responses.extend([resp_type for resp_type in utils.ChartDataResponseType])
|
||||
for response_option in responses:
|
||||
if request.args.get(response_option) == "true":
|
||||
response_type = response_option
|
||||
break
|
||||
|
||||
form_data = get_form_data()[0]
|
||||
|
||||
try:
|
||||
|
|
@ -731,12 +732,10 @@ class Superset(BaseSupersetView):
|
|||
datasource_type=datasource_type,
|
||||
datasource_id=datasource_id,
|
||||
form_data=form_data,
|
||||
force=force,
|
||||
force=request.args.get("force") == "true",
|
||||
)
|
||||
|
||||
return self.generate_json(
|
||||
viz_obj, csv=csv, query=query, results=results, samples=samples
|
||||
)
|
||||
return self.generate_json(viz_obj, response_type)
|
||||
except SupersetException as ex:
|
||||
return json_error_response(utils.error_msg_from_exception(ex))
|
||||
|
||||
|
|
|
|||
|
|
@ -19,7 +19,11 @@ from superset import db
|
|||
from superset.charts.schemas import ChartDataQueryContextSchema
|
||||
from superset.common.query_context import QueryContext
|
||||
from superset.connectors.connector_registry import ConnectorRegistry
|
||||
from superset.utils.core import TimeRangeEndpoint
|
||||
from superset.utils.core import (
|
||||
ChartDataResponseFormat,
|
||||
ChartDataResponseType,
|
||||
TimeRangeEndpoint,
|
||||
)
|
||||
from tests.base_tests import SupersetTestCase
|
||||
from tests.fixtures.query_context import get_query_context
|
||||
|
||||
|
|
@ -131,3 +135,55 @@ class QueryContextTests(SupersetTestCase):
|
|||
query_object = query_context.queries[0]
|
||||
self.assertEqual(query_object.granularity, "timecol")
|
||||
self.assertIn("having_druid", query_object.extras)
|
||||
|
||||
def test_csv_response_format(self):
    """
    Requesting the CSV response format should return the data as CSV text
    """
    self.login(username="admin")
    birth_names = self.get_table_by_name("birth_names")
    context_kwargs = get_query_context(
        birth_names.name, birth_names.id, birth_names.type
    )
    context_kwargs["response_format"] = ChartDataResponseFormat.CSV.value
    context_kwargs["queries"][0]["row_limit"] = 10

    payloads = QueryContext(**context_kwargs).get_payload()
    self.assertEqual(len(payloads), 1)

    csv_text = payloads[0]["data"]
    self.assertIn("name,sum__num\n", csv_text)
    # presumably header + 10 rows + the empty piece after the final newline
    self.assertEqual(len(csv_text.split("\n")), 12)
|
||||
|
||||
def test_samples_response_type(self):
    """
    Requesting the samples response type should return raw rows without
    the metric column
    """
    self.login(username="admin")
    birth_names = self.get_table_by_name("birth_names")
    context_kwargs = get_query_context(
        birth_names.name, birth_names.id, birth_names.type
    )
    context_kwargs["response_type"] = ChartDataResponseType.SAMPLES.value
    context_kwargs["queries"][0]["row_limit"] = 5

    payloads = QueryContext(**context_kwargs).get_payload()
    self.assertEqual(len(payloads), 1)

    rows = payloads[0]["data"]
    self.assertIsInstance(rows, list)
    self.assertEqual(len(rows), 5)
    self.assertNotIn("sum__num", rows[0])
|
||||
|
||||
def test_query_response_type(self):
    """
    Requesting the query response type should return the query string and
    its language instead of data
    """
    self.login(username="admin")
    birth_names = self.get_table_by_name("birth_names")
    context_kwargs = get_query_context(
        birth_names.name, birth_names.id, birth_names.type
    )
    context_kwargs["response_type"] = ChartDataResponseType.QUERY.value

    payloads = QueryContext(**context_kwargs).get_payload()
    self.assertEqual(len(payloads), 1)

    query_payload = payloads[0]
    # only the "query" and "language" keys are expected
    self.assertEqual(len(query_payload), 2)
    self.assertEqual(query_payload["language"], "sql")
    self.assertIn("SELECT", query_payload["query"])
|
||||
|
|
|
|||
Loading…
Reference in New Issue