diff --git a/superset/charts/api.py b/superset/charts/api.py index 90c353249..dbf37d881 100644 --- a/superset/charts/api.py +++ b/superset/charts/api.py @@ -549,13 +549,18 @@ class ChartRestApi(BaseSupersetModelRestApi): 500: $ref: '#/components/responses/500' """ + json_body = None if request.is_json: json_body = request.json elif request.form.get("form_data"): # CSV export submits regular form data - json_body = json.loads(request.form["form_data"]) - else: - return self.response_400(message="Request is not JSON") + try: + json_body = json.loads(request.form["form_data"]) + except (TypeError, json.JSONDecodeError): + json_body = None + + if json_body is None: + return self.response_400(message=_("Request is not JSON")) try: command = ChartDataCommand() diff --git a/superset/views/utils.py b/superset/views/utils.py index 0216026f2..9db65a2e0 100644 --- a/superset/views/utils.py +++ b/superset/views/utils.py @@ -126,6 +126,13 @@ def get_viz( return viz_obj +def loads_request_json(request_json_data: str) -> Dict[Any, Any]: + try: + return json.loads(request_json_data) + except (TypeError, json.JSONDecodeError): + return {} + + def get_form_data( slice_id: Optional[int] = None, use_slice_data: bool = False ) -> Tuple[Dict[str, Any], Optional[Slice]]: @@ -141,10 +148,10 @@ def get_form_data( if request_json_data: form_data.update(request_json_data) if request_form_data: - form_data.update(json.loads(request_form_data)) + form_data.update(loads_request_json(request_form_data)) # request params can overwrite the body if request_args_data: - form_data.update(json.loads(request_args_data)) + form_data.update(loads_request_json(request_args_data)) # Fallback to using the Flask globals (used for cache warmup) if defined. if not form_data and hasattr(g, "form_data"): @@ -157,7 +164,7 @@ def get_form_data( url_str = parse.unquote_plus( saved_url.url.split("?")[1][10:], encoding="utf-8" ) - url_form_data = json.loads(url_str) + url_form_data = loads_request_json(url_str) # allow form_date in request override saved url url_form_data.update(form_data) form_data = url_form_data diff --git a/tests/charts/api_tests.py b/tests/charts/api_tests.py index f7ba39a5c..8e22074fe 100644 --- a/tests/charts/api_tests.py +++ b/tests/charts/api_tests.py @@ -1175,6 +1175,21 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin): rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") self.assertEqual(rv.status_code, 400) + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") + def test_chart_data_invalid_form_data(self): + """ + Chart data API: Test chart data with invalid form_data json + """ + self.login(username="admin") + data = {"form_data": "NOT VALID JSON"} + + rv = self.client.post( + CHART_DATA_URI, data=data, content_type="multipart/form-data" + ) + response = json.loads(rv.data.decode("utf-8")) + self.assertEqual(rv.status_code, 400) + self.assertEqual(response["message"], "Request is not JSON") + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_chart_data_query_result_type(self): """ @@ -1592,7 +1607,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin): assert rv.status_code == 422 assert response == { "message": { - "charts/imported_chart.yaml": "Chart already exists and `overwrite=true` was not passed", + "charts/imported_chart.yaml": "Chart already exists and `overwrite=true` was not passed" } } diff --git a/tests/utils_tests.py b/tests/utils_tests.py index 1c6e6b2e9..fbdf13115 100644 --- a/tests/utils_tests.py +++ b/tests/utils_tests.py @@ -925,7 +925,7 @@ class TestUtils(SupersetTestCase): self.assertEqual( form_data, - {"time_range_endpoints": get_time_range_endpoints(form_data={}),}, + {"time_range_endpoints": get_time_range_endpoints(form_data={})}, ) self.assertEqual(slc, None) @@ -994,6 +994,20 @@ class TestUtils(SupersetTestCase): self.assertEqual(slc, None) + def test_get_form_data_corrupted_json(self) -> None: + with app.test_request_context( + data={"form_data": "{x: '2324'}"}, + query_string={"form_data": '{"baz": "bar"'}, + ): + form_data, slc = get_form_data() + + self.assertEqual( + form_data, + {"time_range_endpoints": get_time_range_endpoints(form_data={})}, + ) + + self.assertEqual(slc, None) + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_log_this(self) -> None: # TODO: Add additional scenarios.