From f104fba61d7ff3f9688f6c5a9806c2c8fab6fb5e Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Wed, 21 Jul 2021 17:03:22 -0700 Subject: [PATCH] feat: add `GET /api/v1/chart/{chart_id}/data/?format{format}` API (#15827) * feat: add `GET /api/v1/chart/{chart_id}/data/?format{format}` API * Fix test --- superset/charts/api.py | 146 +++++++++++++++--- tests/integration_tests/charts/api_tests.py | 120 ++++++++++---- .../security/security_rbac_tests.py | 4 +- 3 files changed, 213 insertions(+), 57 deletions(-) diff --git a/superset/charts/api.py b/superset/charts/api.py index c3658b9f6..0d394cd51 100644 --- a/superset/charts/api.py +++ b/superset/charts/api.py @@ -106,7 +106,8 @@ class ChartRestApi(BaseSupersetModelRestApi): RouteMethod.IMPORT, RouteMethod.RELATED, "bulk_delete", # not using RouteMethod since locally defined - "data", + "post_data", + "get_data", "data_from_cache", "viz_types", "favorite_status", @@ -516,6 +517,95 @@ class ChartRestApi(BaseSupersetModelRestApi): return self.send_chart_response(result) + @expose("//data/", methods=["GET"]) + @protect() + @statsd_metrics + @event_logger.log_this_with_context( + action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.data", + log_to_statsd=False, + ) + def get_data(self, pk: int) -> Response: + """ + Takes a chart ID and uses the query context stored when the chart was saved + to return payload data response. + --- + get: + description: >- + Takes a chart ID and uses the query context stored when the chart was saved + to return payload data response. + parameters: + - in: path + schema: + type: integer + name: pk + description: The chart ID + - in: query + name: format + description: The format in which the data should be returned + schema: + type: string + responses: + 200: + description: Query result + content: + application/json: + schema: + $ref: "#/components/schemas/ChartDataResponseSchema" + 202: + description: Async job details + content: + application/json: + schema: + $ref: "#/components/schemas/ChartDataAsyncResponseSchema" + 400: + $ref: '#/components/responses/400' + 401: + $ref: '#/components/responses/401' + 500: + $ref: '#/components/responses/500' + """ + chart = self.datamodel.get(pk, self._base_filters) + if not chart: + return self.response_404() + + try: + json_body = json.loads(chart.query_context) + except (TypeError, json.decoder.JSONDecodeError): + json_body = None + + if json_body is None: + return self.response_400( + message=_( + "Chart has no query context saved. Please save the chart again." + ) + ) + + json_body["result_format"] = request.args.get( + "format", ChartDataResultFormat.JSON + ) + try: + command = ChartDataCommand() + query_context = command.set_query_context(json_body) + command.validate() + except QueryObjectValidationError as error: + return self.response_400(message=error.message) + except ValidationError as error: + return self.response_400( + message=_( + "Request is incorrect: %(error)s", error=error.normalized_messages() + ) + ) + + # TODO: support CSV, SQL query and other non-JSON types + if ( + is_feature_enabled("GLOBAL_ASYNC_QUERIES") + and query_context.result_format == ChartDataResultFormat.JSON + and query_context.result_type == ChartDataResultType.FULL + ): + return self._run_async(command) + + return self.get_data_response(command) + @expose("/data", methods=["POST"]) @protect() @statsd_metrics @@ -523,7 +613,7 @@ class ChartRestApi(BaseSupersetModelRestApi): action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.data", log_to_statsd=False, ) - def data(self) -> Response: + def post_data(self) -> Response: """ Takes a query context constructed in the client and returns payload data response for the given query. @@ -593,32 +683,38 @@ class ChartRestApi(BaseSupersetModelRestApi): and query_context.result_format == ChartDataResultFormat.JSON and query_context.result_type == ChartDataResultType.FULL ): - # First, look for the chart query results in the cache. - try: - result = command.run(force_cached=True) - except ChartDataCacheLoadError: - result = None # type: ignore - - already_cached_result = result is not None - - # If the chart query has already been cached, return it immediately. - if already_cached_result: - return self.send_chart_response(result) - - # Otherwise, kick off a background job to run the chart query. - # Clients will either poll or be notified of query completion, - # at which point they will call the /data/ endpoint - # to retrieve the results. - try: - command.validate_async_request(request) - except AsyncQueryTokenException: - return self.response_401() - - result = command.run_async(g.user.get_id()) - return self.response(202, **result) + return self._run_async(command) return self.get_data_response(command) + def _run_async(self, command: ChartDataCommand) -> Response: + """ + Execute command as an async query. + """ + # First, look for the chart query results in the cache. + try: + result = command.run(force_cached=True) + except ChartDataCacheLoadError: + result = None # type: ignore + + already_cached_result = result is not None + + # If the chart query has already been cached, return it immediately. + if already_cached_result: + return self.send_chart_response(result) + + # Otherwise, kick off a background job to run the chart query. + # Clients will either poll or be notified of query completion, + # at which point they will call the /data/ endpoint + # to retrieve the results. + try: + command.validate_async_request(request) + except AsyncQueryTokenException: + return self.response_401() + + result = command.run_async(g.user.get_id()) + return self.response(202, **result) + @expose("/data/", methods=["GET"]) @protect() @statsd_metrics diff --git a/tests/integration_tests/charts/api_tests.py b/tests/integration_tests/charts/api_tests.py index ad777019e..b99dbe6db 100644 --- a/tests/integration_tests/charts/api_tests.py +++ b/tests/integration_tests/charts/api_tests.py @@ -186,9 +186,12 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin): rv = self.get_assert_metric(uri, "info") data = json.loads(rv.data.decode("utf-8")) assert rv.status_code == 200 - assert "can_read" in data["permissions"] - assert "can_write" in data["permissions"] - assert len(data["permissions"]) == 2 + assert set(data["permissions"]) == { + "can_get_data", + "can_read", + "can_post_data", + "can_write", + } def create_chart_import(self): buf = BytesIO() @@ -1036,7 +1039,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin): Chart API: Test get charts no data access """ self.login(username="gamma") - uri = f"api/v1/chart/" + uri = "api/v1/chart/" rv = self.get_assert_metric(uri, "get_list") self.assertEqual(rv.status_code, 200) data = json.loads(rv.data.decode("utf-8")) @@ -1049,12 +1052,69 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin): """ self.login(username="admin") request_payload = get_query_context("birth_names") - rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") + rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "post_data") self.assertEqual(rv.status_code, 200) data = json.loads(rv.data.decode("utf-8")) expected_row_count = self.get_expected_row_count("client_id_1") self.assertEqual(data["result"][0]["rowcount"], expected_row_count) + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") + def test_chart_data_get_no_query_context(self): + """ + Chart data API: Test GET endpoint when query context is null + """ + self.login(username="admin") + chart = db.session.query(Slice).filter_by(slice_name="Genders").one() + rv = self.get_assert_metric(f"api/v1/chart/{chart.id}/data/", "get_data") + data = json.loads(rv.data.decode("utf-8")) + assert data == { + "message": "Chart has no query context saved. Please save the chart again." + } + + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") + def test_chart_data_get(self): + """ + Chart data API: Test GET endpoint + """ + self.login(username="admin") + chart = db.session.query(Slice).filter_by(slice_name="Genders").one() + chart.query_context = json.dumps( + { + "datasource": {"id": chart.table.id, "type": "table"}, + "force": False, + "queries": [ + { + "time_range": "1900-01-01T00:00:00 : 2000-01-01T00:00:00", + "granularity": "ds", + "filters": [], + "extras": { + "time_range_endpoints": ["inclusive", "exclusive"], + "having": "", + "having_druid": [], + "where": "", + }, + "applied_time_extras": {}, + "columns": ["gender"], + "metrics": ["sum__num"], + "orderby": [["sum__num", False]], + "annotation_layers": [], + "row_limit": 50000, + "timeseries_limit": 0, + "order_desc": True, + "url_params": {}, + "custom_params": {}, + "custom_form_data": {}, + } + ], + "result_format": "json", + "result_type": "full", + } + ) + rv = self.get_assert_metric(f"api/v1/chart/{chart.id}/data/", "get_data") + data = json.loads(rv.data.decode("utf-8")) + assert data["result"][0]["status"] == "success" + assert data["result"][0]["rowcount"] == 2 + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_chart_data_applied_time_extras(self): """ @@ -1066,7 +1126,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin): "__time_range": "100 years ago : now", "__time_origin": "now", } - rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") + rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "post_data") self.assertEqual(rv.status_code, 200) data = json.loads(rv.data.decode("utf-8")) self.assertEqual( @@ -1095,7 +1155,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin): request_payload["queries"][0]["row_limit"] = 5 request_payload["queries"][0]["row_offset"] = 0 request_payload["queries"][0]["orderby"] = [["name", True]] - rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") + rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "post_data") response_payload = json.loads(rv.data.decode("utf-8")) result = response_payload["result"][0] self.assertEqual(result["rowcount"], 5) @@ -1108,7 +1168,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin): offset = 2 expected_name = result["data"][offset]["name"] request_payload["queries"][0]["row_offset"] = offset - rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") + rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "post_data") response_payload = json.loads(rv.data.decode("utf-8")) result = response_payload["result"][0] self.assertEqual(result["rowcount"], 5) @@ -1125,7 +1185,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin): self.login(username="admin") request_payload = get_query_context("birth_names") del request_payload["queries"][0]["row_limit"] - rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") + rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "post_data") response_payload = json.loads(rv.data.decode("utf-8")) result = response_payload["result"][0] self.assertEqual(result["rowcount"], 7) @@ -1142,7 +1202,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin): request_payload = get_query_context("birth_names") request_payload["result_type"] = utils.ChartDataResultType.SAMPLES request_payload["queries"][0]["row_limit"] = 10 - rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") + rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "post_data") response_payload = json.loads(rv.data.decode("utf-8")) result = response_payload["result"][0] self.assertEqual(result["rowcount"], 5) @@ -1154,7 +1214,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin): self.login(username="admin") request_payload = get_query_context("birth_names") request_payload["result_type"] = "qwerty" - rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") + rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "post_data") self.assertEqual(rv.status_code, 400) @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @@ -1165,7 +1225,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin): self.login(username="admin") request_payload = get_query_context("birth_names") request_payload["result_format"] = "qwerty" - rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") + rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "post_data") self.assertEqual(rv.status_code, 400) @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @@ -1191,7 +1251,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin): self.login(username="admin") request_payload = get_query_context("birth_names") request_payload["result_type"] = utils.ChartDataResultType.QUERY - rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") + rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "post_data") self.assertEqual(rv.status_code, 200) @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @@ -1202,7 +1262,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin): self.login(username="admin") request_payload = get_query_context("birth_names") request_payload["result_format"] = "csv" - rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") + rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "post_data") self.assertEqual(rv.status_code, 200) # Test chart csv without permission @@ -1215,7 +1275,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin): request_payload = get_query_context("birth_names") request_payload["result_format"] = "csv" - rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") + rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "post_data") self.assertEqual(rv.status_code, 403) @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @@ -1227,7 +1287,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin): request_payload = get_query_context("birth_names") request_payload["queries"][0]["filters"][0]["op"] = "In" request_payload["queries"][0]["row_limit"] = 10 - rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") + rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "post_data") response_payload = json.loads(rv.data.decode("utf-8")) result = response_payload["result"][0] self.assertEqual(result["rowcount"], 10) @@ -1253,7 +1313,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin): "op": "!=", "val": ms_epoch, } - rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") + rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "post_data") response_payload = json.loads(rv.data.decode("utf-8")) result = response_payload["result"][0] @@ -1295,7 +1355,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin): }, } ] - rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") + rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "post_data") self.assertEqual(rv.status_code, 200) response_payload = json.loads(rv.data.decode("utf-8")) result = response_payload["result"][0] @@ -1318,7 +1378,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin): {"col": "non_existent_filter", "op": "==", "val": "foo"}, ] request_payload["result_type"] = utils.ChartDataResultType.QUERY - rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") + rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "post_data") self.assertEqual(rv.status_code, 200) response_payload = json.loads(rv.data.decode("utf-8")) assert "non_existent_filter" not in response_payload["result"][0]["query"] @@ -1333,7 +1393,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin): request_payload["queries"][0]["filters"] = [ {"col": "gender", "op": "==", "val": "foo"} ] - rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") + rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "post_data") self.assertEqual(rv.status_code, 200) response_payload = json.loads(rv.data.decode("utf-8")) result = response_payload["result"][0] @@ -1349,7 +1409,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin): request_payload["queries"][0]["filters"] = [] # erroneus WHERE-clause request_payload["queries"][0]["extras"]["where"] = "(gender abc def)" - rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") + rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "post_data") self.assertEqual(rv.status_code, 400) def test_chart_data_with_invalid_datasource(self): @@ -1359,7 +1419,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin): self.login(username="admin") payload = get_query_context("birth_names") payload["datasource"] = "abc" - rv = self.post_assert_metric(CHART_DATA_URI, payload, "data") + rv = self.post_assert_metric(CHART_DATA_URI, payload, "post_data") self.assertEqual(rv.status_code, 400) def test_chart_data_with_invalid_enum_value(self): @@ -1381,7 +1441,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin): """ self.login(username="gamma") payload = get_query_context("birth_names") - rv = self.post_assert_metric(CHART_DATA_URI, payload, "data") + rv = self.post_assert_metric(CHART_DATA_URI, payload, "post_data") self.assertEqual(rv.status_code, 401) response_payload = json.loads(rv.data.decode("utf-8")) assert ( @@ -1403,7 +1463,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin): request_payload["queries"][0]["extras"][ "where" ] = "('boy' = '{{ filter_values('gender', 'xyz' )[0] }}')" - rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") + rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "post_data") response_payload = json.loads(rv.data.decode("utf-8")) result = response_payload["result"][0]["query"] if get_example_database().backend != "presto": @@ -1418,7 +1478,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin): async_query_manager.init_app(app) self.login(username="admin") request_payload = get_query_context("birth_names") - rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") + rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "post_data") self.assertEqual(rv.status_code, 202) data = json.loads(rv.data.decode("utf-8")) keys = list(data.keys()) @@ -1448,7 +1508,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin): ChartDataCommand, "run", return_value=cmd_run_val ) as patched_run: request_payload = get_query_context("birth_names") - rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") + rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "post_data") self.assertEqual(rv.status_code, 200) data = json.loads(rv.data.decode("utf-8")) patched_run.assert_called_once_with(force_cached=True) @@ -1464,7 +1524,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin): self.login(username="admin") request_payload = get_query_context("birth_names") request_payload["result_type"] = "results" - rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") + rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "post_data") self.assertEqual(rv.status_code, 200) @with_feature_flags(GLOBAL_ASYNC_QUERIES=True) @@ -1795,7 +1855,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin): event["value"] = event_layer.id annotation_layers.append(event) - rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") + rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "post_data") self.assertEqual(rv.status_code, 200) data = json.loads(rv.data.decode("utf-8")) # response should only contain interval and event data, not formula @@ -1839,7 +1899,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin): request_payload = get_query_context("birth_names") request_payload["queries"][0]["is_rowcount"] = True request_payload["queries"][0]["groupby"] = ["name"] - rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") + rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "post_data") response_payload = json.loads(rv.data.decode("utf-8")) result = response_payload["result"][0] expected_row_count = self.get_expected_row_count("client_id_4") @@ -1856,7 +1916,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin): {"result_type": utils.ChartDataResultType.TIMEGRAINS}, {"result_type": utils.ChartDataResultType.COLUMNS}, ] - rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") + rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "post_data") response_payload = json.loads(rv.data.decode("utf-8")) timegrain_result = response_payload["result"][0] column_result = response_payload["result"][1] diff --git a/tests/integration_tests/dashboards/security/security_rbac_tests.py b/tests/integration_tests/dashboards/security/security_rbac_tests.py index c1be5a911..25f71c9f8 100644 --- a/tests/integration_tests/dashboards/security/security_rbac_tests.py +++ b/tests/integration_tests/dashboards/security/security_rbac_tests.py @@ -90,7 +90,7 @@ class TestDashboardRoleBasedSecurity(BaseTestDashboardSecurity): response = self.get_dashboard_view_response(dashboard_to_access) request_payload = get_query_context("birth_names") - rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") + rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "post_data") self.assertEqual(rv.status_code, 401) # assert @@ -140,7 +140,7 @@ class TestDashboardRoleBasedSecurity(BaseTestDashboardSecurity): self.assert200(response) request_payload = get_query_context("birth_names") - rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") + rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "post_data") self.assertEqual(rv.status_code, 200) # post