diff --git a/superset-frontend/src/explore/components/ExploreViewContainer.jsx b/superset-frontend/src/explore/components/ExploreViewContainer.jsx index 3ff0d3ebd..b7078d898 100644 --- a/superset-frontend/src/explore/components/ExploreViewContainer.jsx +++ b/superset-frontend/src/explore/components/ExploreViewContainer.jsx @@ -27,7 +27,7 @@ import ExploreChartPanel from './ExploreChartPanel'; import ControlPanelsContainer from './ControlPanelsContainer'; import SaveModal from './SaveModal'; import QueryAndSaveBtns from './QueryAndSaveBtns'; -import { getExploreUrl, getExploreLongUrl } from '../exploreUtils'; +import { getExploreLongUrl } from '../exploreUtils'; import { areObjectsEqual } from '../../reduxUtils'; import { getFormDataFromControls } from '../controlUtils'; import { chartPropShape } from '../../dashboard/util/propShapes'; diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py index 91b71a30a..9d01901da 100644 --- a/superset/charts/schemas.py +++ b/superset/charts/schemas.py @@ -711,6 +711,14 @@ class ChartDataQueryContextSchema(Schema): description="Should the queries be forced to load from the source. " "Default: `false`", ) + result_type = fields.String( + description="Type of results to return", + validate=validate.OneOf(choices=("query", "results", "samples")), + ) + result_format = fields.String( + description="Format of result payload", + validate=validate.OneOf(choices=("json", "csv")), + ) # pylint: disable=no-self-use @post_load diff --git a/superset/common/query_context.py b/superset/common/query_context.py index 3a84ec01f..96fc73338 100644 --- a/superset/common/query_context.py +++ b/superset/common/query_context.py @@ -14,10 +14,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import copy import logging import pickle as pkl from datetime import datetime, timedelta -from typing import Any, ClassVar, Dict, List, Optional +from typing import Any, ClassVar, Dict, List, Optional, Union import numpy as np import pandas as pd @@ -49,15 +50,19 @@ class QueryContext: queries: List[QueryObject] force: bool custom_cache_timeout: Optional[int] + response_type: utils.ChartDataResponseType + response_format: utils.ChartDataResponseFormat # TODO: Type datasource and query_object dictionary with TypedDict when it becomes # a vanilla python type https://github.com/python/mypy/issues/5288 - def __init__( + def __init__( # pylint: disable=too-many-arguments self, datasource: Dict[str, Any], queries: List[Dict[str, Any]], force: bool = False, custom_cache_timeout: Optional[int] = None, + response_format: Optional[utils.ChartDataResponseFormat] = None, + response_type: Optional[utils.ChartDataResponseType] = None, ) -> None: self.datasource = ConnectorRegistry.get_datasource( str(datasource["type"]), int(datasource["id"]), db.session @@ -65,6 +70,8 @@ class QueryContext: self.queries = [QueryObject(**query_obj) for query_obj in queries] self.force = force self.custom_cache_timeout = custom_cache_timeout + self.response_format = response_format or utils.ChartDataResponseFormat.JSON + self.response_type = response_type or utils.ChartDataResponseType.RESULTS def get_query_result(self, query_object: QueryObject) -> Dict[str, Any]: """Returns a pandas dataframe based on the query object""" @@ -124,12 +131,32 @@ class QueryContext: if dtype.type == np.object_ and col in query_object.metrics: df[col] = pd.to_numeric(df[col], errors="coerce") - @staticmethod - def get_data(df: pd.DataFrame,) -> List[Dict]: # pylint: disable=no-self-use + def get_data( + self, df: pd.DataFrame, + ) -> Union[str, List[Dict[str, Any]]]: # pylint: disable=no-self-use + if self.response_format == utils.ChartDataResponseFormat.CSV: + include_index = not isinstance(df.index, pd.RangeIndex) + result = df.to_csv(index=include_index, **config["CSV_EXPORT"]) + return result or "" + return df.to_dict(orient="records") def get_single_payload(self, query_obj: QueryObject) -> Dict[str, Any]: """Returns a payload of metadata and data""" + if self.response_type == utils.ChartDataResponseType.QUERY: + return { + "query": self.datasource.get_query_str(query_obj.to_dict()), + "language": self.datasource.query_language, + } + if self.response_type == utils.ChartDataResponseType.SAMPLES: + row_limit = query_obj.row_limit or 1000 + query_obj = copy.copy(query_obj) + query_obj.groupby = [] + query_obj.metrics = [] + query_obj.post_processing = [] + query_obj.row_limit = row_limit + query_obj.columns = [o.column_name for o in self.datasource.columns] + payload = self.get_df_payload(query_obj) df = payload["df"] status = payload["status"] @@ -142,7 +169,7 @@ class QueryContext: return payload def get_payload(self) -> List[Dict[str, Any]]: - """Get all the payloads from the arrays""" + """Get all the payloads from the QueryObjects""" return [self.get_single_payload(query_object) for query_object in self.queries] @property diff --git a/superset/utils/core.py b/superset/utils/core.py index 47fe8c889..19e39ab08 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -1367,3 +1367,22 @@ class FilterOperator(str, Enum): IN = "IN" NOT_IN = "NOT IN" REGEX = "REGEX" + + +class ChartDataResponseType(str, Enum): + """ + Chart data response type + """ + + QUERY = "query" + RESULTS = "results" + SAMPLES = "samples" + + +class ChartDataResponseFormat(str, Enum): + """ + Chart data response format + """ + + CSV = "csv" + JSON = "json" diff --git a/superset/views/core.py b/superset/views/core.py index 25d65699b..10278684f 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -636,10 +636,8 @@ class Superset(BaseSupersetView): def get_samples(self, viz_obj): return self.json_response({"data": viz_obj.get_samples()}) - def generate_json( - self, viz_obj, csv=False, query=False, results=False, samples=False - ): - if csv: + def generate_json(self, viz_obj, response_type: Optional[str] = None) -> Response: + if response_type == utils.ChartDataResponseFormat.CSV: return CsvResponse( viz_obj.get_csv(), status=200, @@ -647,13 +645,13 @@ class Superset(BaseSupersetView): mimetype="application/csv", ) - if query: + if response_type == utils.ChartDataResponseType.QUERY: return self.get_query_string_response(viz_obj) - if results: + if response_type == utils.ChartDataResponseType.RESULTS: return self.get_raw_results(viz_obj) - if samples: + if response_type == utils.ChartDataResponseType.SAMPLES: return self.get_samples(viz_obj) payload = viz_obj.get_payload() @@ -715,11 +713,14 @@ class Superset(BaseSupersetView): payloads based on the request args in the first block TODO: break into one endpoint for each return shape""" - csv = request.args.get("csv") == "true" - query = request.args.get("query") == "true" - results = request.args.get("results") == "true" - samples = request.args.get("samples") == "true" - force = request.args.get("force") == "true" + response_type = utils.ChartDataResponseFormat.JSON.value + responses = [resp_format for resp_format in utils.ChartDataResponseFormat] + responses.extend([resp_type for resp_type in utils.ChartDataResponseType]) + for response_option in responses: + if request.args.get(response_option) == "true": + response_type = response_option + break + form_data = get_form_data()[0] try: @@ -731,12 +732,10 @@ class Superset(BaseSupersetView): datasource_type=datasource_type, datasource_id=datasource_id, form_data=form_data, - force=force, + force=request.args.get("force") == "true", ) - return self.generate_json( - viz_obj, csv=csv, query=query, results=results, samples=samples - ) + return self.generate_json(viz_obj, response_type) except SupersetException as ex: return json_error_response(utils.error_msg_from_exception(ex)) diff --git a/tests/query_context_tests.py b/tests/query_context_tests.py index e27c0af75..7d6117ab6 100644 --- a/tests/query_context_tests.py +++ b/tests/query_context_tests.py @@ -19,7 +19,11 @@ from superset import db from superset.charts.schemas import ChartDataQueryContextSchema from superset.common.query_context import QueryContext from superset.connectors.connector_registry import ConnectorRegistry -from superset.utils.core import TimeRangeEndpoint +from superset.utils.core import ( + ChartDataResponseFormat, + ChartDataResponseType, + TimeRangeEndpoint, +) from tests.base_tests import SupersetTestCase from tests.fixtures.query_context import get_query_context @@ -131,3 +135,55 @@ class QueryContextTests(SupersetTestCase): query_object = query_context.queries[0] self.assertEqual(query_object.granularity, "timecol") self.assertIn("having_druid", query_object.extras) + + def test_csv_response_format(self): + """ + Ensure that CSV result format works + """ + self.login(username="admin") + table_name = "birth_names" + table = self.get_table_by_name(table_name) + payload = get_query_context(table.name, table.id, table.type) + payload["response_format"] = ChartDataResponseFormat.CSV.value + payload["queries"][0]["row_limit"] = 10 + query_context = QueryContext(**payload) + responses = query_context.get_payload() + self.assertEqual(len(responses), 1) + data = responses[0]["data"] + self.assertIn("name,sum__num\n", data) + self.assertEqual(len(data.split("\n")), 12) + + def test_samples_response_type(self): + """ + Ensure that samples result type works + """ + self.login(username="admin") + table_name = "birth_names" + table = self.get_table_by_name(table_name) + payload = get_query_context(table.name, table.id, table.type) + payload["response_type"] = ChartDataResponseType.SAMPLES.value + payload["queries"][0]["row_limit"] = 5 + query_context = QueryContext(**payload) + responses = query_context.get_payload() + self.assertEqual(len(responses), 1) + data = responses[0]["data"] + self.assertIsInstance(data, list) + self.assertEqual(len(data), 5) + self.assertNotIn("sum__num", data[0]) + + def test_query_response_type(self): + """ + Ensure that query result type works + """ + self.login(username="admin") + table_name = "birth_names" + table = self.get_table_by_name(table_name) + payload = get_query_context(table.name, table.id, table.type) + payload["response_type"] = ChartDataResponseType.QUERY.value + query_context = QueryContext(**payload) + responses = query_context.get_payload() + self.assertEqual(len(responses), 1) + response = responses[0] + self.assertEqual(len(response), 2) + self.assertEqual(response["language"], "sql") + self.assertIn("SELECT", response["query"])