fix: Support Jinja template functions in global async queries (#16412)
* Support Jinja template functions in async queries * Pylint * Add tests for async tasks * Remove redundant has_request_context check
This commit is contained in:
parent
a0db5367d2
commit
4e380db3fd
|
|
@ -30,7 +30,7 @@ from typing import (
|
|||
Union,
|
||||
)
|
||||
|
||||
from flask import current_app, g, request
|
||||
from flask import current_app, g, has_request_context, request
|
||||
from flask_babel import gettext as _
|
||||
from jinja2 import DebugUndefined
|
||||
from jinja2.sandbox import SandboxedEnvironment
|
||||
|
|
@ -172,8 +172,9 @@ class ExtraCache:
|
|||
# pylint: disable=import-outside-toplevel
|
||||
from superset.views.utils import get_form_data
|
||||
|
||||
if request.args.get(param):
|
||||
if has_request_context() and request.args.get(param): # type: ignore
|
||||
return request.args.get(param, default)
|
||||
|
||||
form_data, _ = get_form_data()
|
||||
url_params = form_data.get("url_params") or {}
|
||||
result = url_params.get(param, default)
|
||||
|
|
|
|||
|
|
@ -46,6 +46,10 @@ def ensure_user_is_set(user_id: Optional[int]) -> None:
|
|||
g.user = security_manager.get_anonymous_user()
|
||||
|
||||
|
||||
def set_form_data(form_data: Dict[str, Any]) -> None:
|
||||
g.form_data = form_data
|
||||
|
||||
|
||||
@celery_app.task(name="load_chart_data_into_cache", soft_time_limit=query_timeout)
|
||||
def load_chart_data_into_cache(
|
||||
job_metadata: Dict[str, Any], form_data: Dict[str, Any],
|
||||
|
|
@ -55,6 +59,7 @@ def load_chart_data_into_cache(
|
|||
|
||||
try:
|
||||
ensure_user_is_set(job_metadata.get("user_id"))
|
||||
set_form_data(form_data)
|
||||
command = ChartDataCommand()
|
||||
command.set_query_context(form_data)
|
||||
result = command.run(cache=True)
|
||||
|
|
@ -86,6 +91,7 @@ def load_explore_json_into_cache( # pylint: disable=too-many-locals
|
|||
cache_key_prefix = "ejr-" # ejr: explore_json request
|
||||
try:
|
||||
ensure_user_is_set(job_metadata.get("user_id"))
|
||||
set_form_data(form_data)
|
||||
datasource_id, datasource_type = get_datasource_info(None, None, form_data)
|
||||
|
||||
# Perform a deep copy here so that below we can cache the original
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ from urllib import parse
|
|||
import msgpack
|
||||
import pyarrow as pa
|
||||
import simplejson as json
|
||||
from flask import g, request
|
||||
from flask import g, has_request_context, request
|
||||
from flask_appbuilder.security.sqla import models as ab_models
|
||||
from flask_appbuilder.security.sqla.models import User
|
||||
from flask_babel import _
|
||||
|
|
@ -130,46 +130,52 @@ def get_form_data( # pylint: disable=too-many-locals
|
|||
slice_id: Optional[int] = None, use_slice_data: bool = False
|
||||
) -> Tuple[Dict[str, Any], Optional[Slice]]:
|
||||
form_data: Dict[str, Any] = {}
|
||||
# chart data API requests are JSON
|
||||
request_json_data = (
|
||||
request.json["queries"][0]
|
||||
if request.is_json and "queries" in request.json
|
||||
else None
|
||||
)
|
||||
|
||||
add_sqllab_custom_filters(form_data)
|
||||
if has_request_context(): # type: ignore
|
||||
# chart data API requests are JSON
|
||||
request_json_data = (
|
||||
request.json["queries"][0]
|
||||
if request.is_json and "queries" in request.json
|
||||
else None
|
||||
)
|
||||
|
||||
request_form_data = request.form.get("form_data")
|
||||
request_args_data = request.args.get("form_data")
|
||||
if request_json_data:
|
||||
form_data.update(request_json_data)
|
||||
if request_form_data:
|
||||
parsed_form_data = loads_request_json(request_form_data)
|
||||
# some chart data api requests are form_data
|
||||
queries = parsed_form_data.get("queries")
|
||||
if isinstance(queries, list):
|
||||
form_data.update(queries[0])
|
||||
else:
|
||||
form_data.update(parsed_form_data)
|
||||
# request params can overwrite the body
|
||||
if request_args_data:
|
||||
form_data.update(loads_request_json(request_args_data))
|
||||
add_sqllab_custom_filters(form_data)
|
||||
|
||||
# Fallback to using the Flask globals (used for cache warmup) if defined.
|
||||
request_form_data = request.form.get("form_data")
|
||||
request_args_data = request.args.get("form_data")
|
||||
if request_json_data:
|
||||
form_data.update(request_json_data)
|
||||
if request_form_data:
|
||||
parsed_form_data = loads_request_json(request_form_data)
|
||||
# some chart data api requests are form_data
|
||||
queries = parsed_form_data.get("queries")
|
||||
if isinstance(queries, list):
|
||||
form_data.update(queries[0])
|
||||
else:
|
||||
form_data.update(parsed_form_data)
|
||||
# request params can overwrite the body
|
||||
if request_args_data:
|
||||
form_data.update(loads_request_json(request_args_data))
|
||||
|
||||
# Fallback to using the Flask globals (used for cache warmup and async queries)
|
||||
if not form_data and hasattr(g, "form_data"):
|
||||
form_data = getattr(g, "form_data")
|
||||
# chart data API requests are JSON
|
||||
json_data = form_data["queries"][0] if "queries" in form_data else {}
|
||||
form_data.update(json_data)
|
||||
|
||||
url_id = request.args.get("r")
|
||||
if url_id:
|
||||
saved_url = db.session.query(models.Url).filter_by(id=url_id).first()
|
||||
if saved_url:
|
||||
url_str = parse.unquote_plus(
|
||||
saved_url.url.split("?")[1][10:], encoding="utf-8"
|
||||
)
|
||||
url_form_data = loads_request_json(url_str)
|
||||
# allow form_date in request override saved url
|
||||
url_form_data.update(form_data)
|
||||
form_data = url_form_data
|
||||
if has_request_context(): # type: ignore
|
||||
url_id = request.args.get("r")
|
||||
if url_id:
|
||||
saved_url = db.session.query(models.Url).filter_by(id=url_id).first()
|
||||
if saved_url:
|
||||
url_str = parse.unquote_plus(
|
||||
saved_url.url.split("?")[1][10:], encoding="utf-8"
|
||||
)
|
||||
url_form_data = loads_request_json(url_str)
|
||||
# allow form_date in request override saved url
|
||||
url_form_data.update(form_data)
|
||||
form_data = url_form_data
|
||||
|
||||
form_data = {k: v for k, v in form_data.items() if k not in REJECTED_FORM_DATA_KEYS}
|
||||
|
||||
|
|
|
|||
|
|
@ -45,7 +45,8 @@ from tests.integration_tests.test_app import app
|
|||
class TestAsyncQueries(SupersetTestCase):
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
@mock.patch.object(async_query_manager, "update_job")
|
||||
def test_load_chart_data_into_cache(self, mock_update_job):
|
||||
@mock.patch.object(async_queries, "set_form_data")
|
||||
def test_load_chart_data_into_cache(self, mock_set_form_data, mock_update_job):
|
||||
async_query_manager.init_app(app)
|
||||
query_context = get_query_context("birth_names")
|
||||
user = security_manager.find_user("gamma")
|
||||
|
|
@ -63,6 +64,7 @@ class TestAsyncQueries(SupersetTestCase):
|
|||
load_chart_data_into_cache(job_metadata, query_context)
|
||||
|
||||
ensure_user_is_set.assert_called_once_with(user.id)
|
||||
mock_set_form_data.assert_called_once_with(query_context)
|
||||
mock_update_job.assert_called_once_with(
|
||||
job_metadata, "done", result_url=mock.ANY
|
||||
)
|
||||
|
|
@ -154,7 +156,10 @@ class TestAsyncQueries(SupersetTestCase):
|
|||
)
|
||||
|
||||
@mock.patch.object(async_query_manager, "update_job")
|
||||
def test_load_explore_json_into_cache_error(self, mock_update_job):
|
||||
@mock.patch.object(async_queries, "set_form_data")
|
||||
def test_load_explore_json_into_cache_error(
|
||||
self, mock_set_form_data, mock_update_job
|
||||
):
|
||||
async_query_manager.init_app(app)
|
||||
user = security_manager.find_user("gamma")
|
||||
form_data = {}
|
||||
|
|
@ -173,6 +178,7 @@ class TestAsyncQueries(SupersetTestCase):
|
|||
load_explore_json_into_cache(job_metadata, form_data)
|
||||
ensure_user_is_set.assert_called_once_with(user.id)
|
||||
|
||||
mock_set_form_data.assert_called_once_with(form_data)
|
||||
errors = ["The dataset associated with this chart no longer exists"]
|
||||
mock_update_job.assert_called_once_with(job_metadata, "error", errors=errors)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue