feat: Add new result formats and types to chart data API (#9841)

* feat: Add new result formats and types to chart data API

* lint

* Linting

* Add language to query payload

* Fix tests

* simplify tests
This commit is contained in:
Ville Brofeldt 2020-05-20 21:36:14 +03:00 committed by GitHub
parent 368c85de1b
commit a43a1d6303
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 132 additions and 23 deletions

View File

@ -27,7 +27,7 @@ import ExploreChartPanel from './ExploreChartPanel';
import ControlPanelsContainer from './ControlPanelsContainer';
import SaveModal from './SaveModal';
import QueryAndSaveBtns from './QueryAndSaveBtns';
import { getExploreUrl, getExploreLongUrl } from '../exploreUtils';
import { getExploreLongUrl } from '../exploreUtils';
import { areObjectsEqual } from '../../reduxUtils';
import { getFormDataFromControls } from '../controlUtils';
import { chartPropShape } from '../../dashboard/util/propShapes';

View File

@ -711,6 +711,14 @@ class ChartDataQueryContextSchema(Schema):
description="Should the queries be forced to load from the source. "
"Default: `false`",
)
result_type = fields.String(
description="Type of results to return",
validate=validate.OneOf(choices=("query", "results", "samples")),
)
result_format = fields.String(
description="Format of result payload",
validate=validate.OneOf(choices=("json", "csv")),
)
# pylint: disable=no-self-use
@post_load

View File

@ -14,10 +14,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import copy
import logging
import pickle as pkl
from datetime import datetime, timedelta
from typing import Any, ClassVar, Dict, List, Optional
from typing import Any, ClassVar, Dict, List, Optional, Union
import numpy as np
import pandas as pd
@ -49,15 +50,19 @@ class QueryContext:
queries: List[QueryObject]
force: bool
custom_cache_timeout: Optional[int]
response_type: utils.ChartDataResponseType
response_format: utils.ChartDataResponseFormat
# TODO: Type datasource and query_object dictionary with TypedDict when it becomes
# a vanilla python type https://github.com/python/mypy/issues/5288
def __init__(
def __init__( # pylint: disable=too-many-arguments
self,
datasource: Dict[str, Any],
queries: List[Dict[str, Any]],
force: bool = False,
custom_cache_timeout: Optional[int] = None,
response_format: Optional[utils.ChartDataResponseFormat] = None,
response_type: Optional[utils.ChartDataResponseType] = None,
) -> None:
self.datasource = ConnectorRegistry.get_datasource(
str(datasource["type"]), int(datasource["id"]), db.session
@ -65,6 +70,8 @@ class QueryContext:
self.queries = [QueryObject(**query_obj) for query_obj in queries]
self.force = force
self.custom_cache_timeout = custom_cache_timeout
self.response_format = response_format or utils.ChartDataResponseFormat.JSON
self.response_type = response_type or utils.ChartDataResponseType.RESULTS
def get_query_result(self, query_object: QueryObject) -> Dict[str, Any]:
"""Returns a pandas dataframe based on the query object"""
@ -124,12 +131,32 @@ class QueryContext:
if dtype.type == np.object_ and col in query_object.metrics:
df[col] = pd.to_numeric(df[col], errors="coerce")
@staticmethod
def get_data(df: pd.DataFrame,) -> List[Dict]: # pylint: disable=no-self-use
def get_data(
self, df: pd.DataFrame,
) -> Union[str, List[Dict[str, Any]]]: # pylint: disable=no-self-use
if self.response_format == utils.ChartDataResponseFormat.CSV:
include_index = not isinstance(df.index, pd.RangeIndex)
result = df.to_csv(index=include_index, **config["CSV_EXPORT"])
return result or ""
return df.to_dict(orient="records")
def get_single_payload(self, query_obj: QueryObject) -> Dict[str, Any]:
"""Returns a payload of metadata and data"""
if self.response_type == utils.ChartDataResponseType.QUERY:
return {
"query": self.datasource.get_query_str(query_obj.to_dict()),
"language": self.datasource.query_language,
}
if self.response_type == utils.ChartDataResponseType.SAMPLES:
row_limit = query_obj.row_limit or 1000
query_obj = copy.copy(query_obj)
query_obj.groupby = []
query_obj.metrics = []
query_obj.post_processing = []
query_obj.row_limit = row_limit
query_obj.columns = [o.column_name for o in self.datasource.columns]
payload = self.get_df_payload(query_obj)
df = payload["df"]
status = payload["status"]
@ -142,7 +169,7 @@ class QueryContext:
return payload
def get_payload(self) -> List[Dict[str, Any]]:
"""Get all the payloads from the arrays"""
"""Get all the payloads from the QueryObjects"""
return [self.get_single_payload(query_object) for query_object in self.queries]
@property

View File

@ -1367,3 +1367,22 @@ class FilterOperator(str, Enum):
IN = "IN"
NOT_IN = "NOT IN"
REGEX = "REGEX"
class ChartDataResponseType(str, Enum):
"""
Chart data response type
"""
QUERY = "query"
RESULTS = "results"
SAMPLES = "samples"
class ChartDataResponseFormat(str, Enum):
"""
Chart data response format
"""
CSV = "csv"
JSON = "json"

View File

@ -636,10 +636,8 @@ class Superset(BaseSupersetView):
def get_samples(self, viz_obj):
return self.json_response({"data": viz_obj.get_samples()})
def generate_json(
self, viz_obj, csv=False, query=False, results=False, samples=False
):
if csv:
def generate_json(self, viz_obj, response_type: Optional[str] = None) -> Response:
if response_type == utils.ChartDataResponseFormat.CSV:
return CsvResponse(
viz_obj.get_csv(),
status=200,
@ -647,13 +645,13 @@ class Superset(BaseSupersetView):
mimetype="application/csv",
)
if query:
if response_type == utils.ChartDataResponseType.QUERY:
return self.get_query_string_response(viz_obj)
if results:
if response_type == utils.ChartDataResponseType.RESULTS:
return self.get_raw_results(viz_obj)
if samples:
if response_type == utils.ChartDataResponseType.SAMPLES:
return self.get_samples(viz_obj)
payload = viz_obj.get_payload()
@ -715,11 +713,14 @@ class Superset(BaseSupersetView):
payloads based on the request args in the first block
TODO: break into one endpoint for each return shape"""
csv = request.args.get("csv") == "true"
query = request.args.get("query") == "true"
results = request.args.get("results") == "true"
samples = request.args.get("samples") == "true"
force = request.args.get("force") == "true"
response_type = utils.ChartDataResponseFormat.JSON.value
responses = [resp_format for resp_format in utils.ChartDataResponseFormat]
responses.extend([resp_type for resp_type in utils.ChartDataResponseType])
for response_option in responses:
if request.args.get(response_option) == "true":
response_type = response_option
break
form_data = get_form_data()[0]
try:
@ -731,12 +732,10 @@ class Superset(BaseSupersetView):
datasource_type=datasource_type,
datasource_id=datasource_id,
form_data=form_data,
force=force,
force=request.args.get("force") == "true",
)
return self.generate_json(
viz_obj, csv=csv, query=query, results=results, samples=samples
)
return self.generate_json(viz_obj, response_type)
except SupersetException as ex:
return json_error_response(utils.error_msg_from_exception(ex))

View File

@ -19,7 +19,11 @@ from superset import db
from superset.charts.schemas import ChartDataQueryContextSchema
from superset.common.query_context import QueryContext
from superset.connectors.connector_registry import ConnectorRegistry
from superset.utils.core import TimeRangeEndpoint
from superset.utils.core import (
ChartDataResponseFormat,
ChartDataResponseType,
TimeRangeEndpoint,
)
from tests.base_tests import SupersetTestCase
from tests.fixtures.query_context import get_query_context
@ -131,3 +135,55 @@ class QueryContextTests(SupersetTestCase):
query_object = query_context.queries[0]
self.assertEqual(query_object.granularity, "timecol")
self.assertIn("having_druid", query_object.extras)
def test_csv_response_format(self):
"""
Ensure that CSV result format works
"""
self.login(username="admin")
table_name = "birth_names"
table = self.get_table_by_name(table_name)
payload = get_query_context(table.name, table.id, table.type)
payload["response_format"] = ChartDataResponseFormat.CSV.value
payload["queries"][0]["row_limit"] = 10
query_context = QueryContext(**payload)
responses = query_context.get_payload()
self.assertEqual(len(responses), 1)
data = responses[0]["data"]
self.assertIn("name,sum__num\n", data)
self.assertEqual(len(data.split("\n")), 12)
def test_samples_response_type(self):
"""
Ensure that samples result type works
"""
self.login(username="admin")
table_name = "birth_names"
table = self.get_table_by_name(table_name)
payload = get_query_context(table.name, table.id, table.type)
payload["response_type"] = ChartDataResponseType.SAMPLES.value
payload["queries"][0]["row_limit"] = 5
query_context = QueryContext(**payload)
responses = query_context.get_payload()
self.assertEqual(len(responses), 1)
data = responses[0]["data"]
self.assertIsInstance(data, list)
self.assertEqual(len(data), 5)
self.assertNotIn("sum__num", data[0])
def test_query_response_type(self):
"""
Ensure that query result type works
"""
self.login(username="admin")
table_name = "birth_names"
table = self.get_table_by_name(table_name)
payload = get_query_context(table.name, table.id, table.type)
payload["response_type"] = ChartDataResponseType.QUERY.value
query_context = QueryContext(**payload)
responses = query_context.get_payload()
self.assertEqual(len(responses), 1)
response = responses[0]
self.assertEqual(len(response), 2)
self.assertEqual(response["language"], "sql")
self.assertIn("SELECT", response["query"])