diff --git a/superset/charts/data/api.py b/superset/charts/data/api.py index 37703339e..534101bae 100644 --- a/superset/charts/data/api.py +++ b/superset/charts/data/api.py @@ -28,11 +28,14 @@ from marshmallow import ValidationError from superset import is_feature_enabled, security_manager from superset.charts.api import ChartRestApi -from superset.charts.commands.data import ChartDataCommand from superset.charts.commands.exceptions import ( ChartDataCacheLoadError, ChartDataQueryFailedError, ) +from superset.charts.data.commands import ( + ChartDataCommand, + CreateAsyncChartDataJobCommand, +) from superset.charts.data.query_context_cache_loader import QueryContextCacheLoader from superset.charts.post_processing import apply_post_process from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType @@ -145,7 +148,7 @@ class ChartDataRestApi(ChartRestApi): and query_context.result_format == ChartDataResultFormat.JSON and query_context.result_type == ChartDataResultType.FULL ): - return self._run_async(command) + return self._run_async(json_body, command) try: form_data = json.loads(chart.params) @@ -231,7 +234,7 @@ class ChartDataRestApi(ChartRestApi): and query_context.result_format == ChartDataResultFormat.JSON and query_context.result_type == ChartDataResultType.FULL ): - return self._run_async(command) + return self._run_async(json_body, command) return self._get_data_response(command) @@ -289,7 +292,9 @@ class ChartDataRestApi(ChartRestApi): return self._get_data_response(command, True) - def _run_async(self, command: ChartDataCommand) -> Response: + def _run_async( + self, form_data: Dict[str, Any], command: ChartDataCommand + ) -> Response: """ Execute command as an async query. """ @@ -309,12 +314,13 @@ class ChartDataRestApi(ChartRestApi): # Clients will either poll or be notified of query completion, # at which point they will call the /data/ endpoint # to retrieve the results. + async_command = CreateAsyncChartDataJobCommand() try: - command.validate_async_request(request) + async_command.validate(request) except AsyncQueryTokenException: return self.response_401() - result = command.run_async(g.user.get_id()) + result = async_command.run(form_data, g.user.get_id()) return self.response(202, **result) def _send_chart_response( diff --git a/superset/charts/commands/data.py b/superset/charts/data/commands.py similarity index 88% rename from superset/charts/commands/data.py rename to superset/charts/data/commands.py index ec63362a5..d434f79a1 100644 --- a/superset/charts/commands/data.py +++ b/superset/charts/data/commands.py @@ -35,10 +35,7 @@ logger = logging.getLogger(__name__) class ChartDataCommand(BaseCommand): - def __init__(self) -> None: - self._form_data: Dict[str, Any] - self._query_context: QueryContext - self._async_channel_id: str + _query_context: QueryContext def run(self, **kwargs: Any) -> Dict[str, Any]: # caching is handled in query_context.get_df_payload @@ -66,26 +63,27 @@ class ChartDataCommand(BaseCommand): return return_value - def run_async(self, user_id: Optional[str]) -> Dict[str, Any]: - job_metadata = async_query_manager.init_job(self._async_channel_id, user_id) - load_chart_data_into_cache.delay(job_metadata, self._form_data) - - return job_metadata - def set_query_context(self, form_data: Dict[str, Any]) -> QueryContext: - self._form_data = form_data try: - self._query_context = ChartDataQueryContextSchema().load(self._form_data) + self._query_context = ChartDataQueryContextSchema().load(form_data) except KeyError as ex: raise ValidationError("Request is incorrect") from ex except ValidationError as error: raise error - return self._query_context def validate(self) -> None: self._query_context.raise_for_access() - def validate_async_request(self, request: Request) -> None: + +class CreateAsyncChartDataJobCommand: + _async_channel_id: str + + def validate(self, request: Request) -> None: jwt_data = async_query_manager.parse_jwt_from_request(request) self._async_channel_id = jwt_data["channel"] + + def run(self, form_data: Dict[str, Any], user_id: Optional[str]) -> Dict[str, Any]: + job_metadata = async_query_manager.init_job(self._async_channel_id, user_id) + load_chart_data_into_cache.delay(job_metadata, form_data) + return job_metadata diff --git a/superset/tasks/async_queries.py b/superset/tasks/async_queries.py index 18094323e..c50dbb9a9 100644 --- a/superset/tasks/async_queries.py +++ b/superset/tasks/async_queries.py @@ -55,7 +55,7 @@ def load_chart_data_into_cache( job_metadata: Dict[str, Any], form_data: Dict[str, Any], ) -> None: # pylint: disable=import-outside-toplevel - from superset.charts.commands.data import ChartDataCommand + from superset.charts.data.commands import ChartDataCommand try: ensure_user_is_set(job_metadata.get("user_id")) diff --git a/tests/integration_tests/charts/data/api_tests.py b/tests/integration_tests/charts/data/api_tests.py index 45d300b73..1b2ade28f 100644 --- a/tests/integration_tests/charts/data/api_tests.py +++ b/tests/integration_tests/charts/data/api_tests.py @@ -37,7 +37,7 @@ from tests.integration_tests.test_app import app import pytest -from superset.charts.commands.data import ChartDataCommand +from superset.charts.data.commands import ChartDataCommand from superset.connectors.sqla.models import TableColumn, SqlaTable from superset.errors import SupersetErrorType from superset.extensions import async_query_manager, db diff --git a/tests/integration_tests/tasks/async_queries_tests.py b/tests/integration_tests/tasks/async_queries_tests.py index 3ea1c6f0c..e2cf21c55 100644 --- a/tests/integration_tests/tasks/async_queries_tests.py +++ b/tests/integration_tests/tasks/async_queries_tests.py @@ -22,10 +22,8 @@ import pytest from celery.exceptions import SoftTimeLimitExceeded from flask import g -from superset import db -from superset.charts.commands.data import ChartDataCommand from superset.charts.commands.exceptions import ChartDataQueryFailedError -from superset.connectors.sqla.models import SqlaTable +from superset.charts.data.commands import ChartDataCommand from superset.exceptions import SupersetException from superset.extensions import async_query_manager, security_manager from superset.tasks import async_queries