fix: Support Jinja template functions in global async queries (#16412)

* Support Jinja template functions in async queries

* Pylint

* Add tests for async tasks

* Remove redundant has_request_context check
This commit is contained in:
Rob DiCiuccio 2021-09-03 04:33:29 -07:00 committed by GitHub
parent a0db5367d2
commit 4e380db3fd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 58 additions and 39 deletions

View File

@ -30,7 +30,7 @@ from typing import (
Union,
)
from flask import current_app, g, request
from flask import current_app, g, has_request_context, request
from flask_babel import gettext as _
from jinja2 import DebugUndefined
from jinja2.sandbox import SandboxedEnvironment
@ -172,8 +172,9 @@ class ExtraCache:
# pylint: disable=import-outside-toplevel
from superset.views.utils import get_form_data
if request.args.get(param):
if has_request_context() and request.args.get(param): # type: ignore
return request.args.get(param, default)
form_data, _ = get_form_data()
url_params = form_data.get("url_params") or {}
result = url_params.get(param, default)

View File

@ -46,6 +46,10 @@ def ensure_user_is_set(user_id: Optional[int]) -> None:
g.user = security_manager.get_anonymous_user()
def set_form_data(form_data: Dict[str, Any]) -> None:
g.form_data = form_data
@celery_app.task(name="load_chart_data_into_cache", soft_time_limit=query_timeout)
def load_chart_data_into_cache(
job_metadata: Dict[str, Any], form_data: Dict[str, Any],
@ -55,6 +59,7 @@ def load_chart_data_into_cache(
try:
ensure_user_is_set(job_metadata.get("user_id"))
set_form_data(form_data)
command = ChartDataCommand()
command.set_query_context(form_data)
result = command.run(cache=True)
@ -86,6 +91,7 @@ def load_explore_json_into_cache( # pylint: disable=too-many-locals
cache_key_prefix = "ejr-" # ejr: explore_json request
try:
ensure_user_is_set(job_metadata.get("user_id"))
set_form_data(form_data)
datasource_id, datasource_type = get_datasource_info(None, None, form_data)
# Perform a deep copy here so that below we can cache the original

View File

@ -24,7 +24,7 @@ from urllib import parse
import msgpack
import pyarrow as pa
import simplejson as json
from flask import g, request
from flask import g, has_request_context, request
from flask_appbuilder.security.sqla import models as ab_models
from flask_appbuilder.security.sqla.models import User
from flask_babel import _
@ -130,46 +130,52 @@ def get_form_data( # pylint: disable=too-many-locals
slice_id: Optional[int] = None, use_slice_data: bool = False
) -> Tuple[Dict[str, Any], Optional[Slice]]:
form_data: Dict[str, Any] = {}
# chart data API requests are JSON
request_json_data = (
request.json["queries"][0]
if request.is_json and "queries" in request.json
else None
)
add_sqllab_custom_filters(form_data)
if has_request_context(): # type: ignore
# chart data API requests are JSON
request_json_data = (
request.json["queries"][0]
if request.is_json and "queries" in request.json
else None
)
request_form_data = request.form.get("form_data")
request_args_data = request.args.get("form_data")
if request_json_data:
form_data.update(request_json_data)
if request_form_data:
parsed_form_data = loads_request_json(request_form_data)
# some chart data api requests are form_data
queries = parsed_form_data.get("queries")
if isinstance(queries, list):
form_data.update(queries[0])
else:
form_data.update(parsed_form_data)
# request params can overwrite the body
if request_args_data:
form_data.update(loads_request_json(request_args_data))
add_sqllab_custom_filters(form_data)
# Fallback to using the Flask globals (used for cache warmup) if defined.
request_form_data = request.form.get("form_data")
request_args_data = request.args.get("form_data")
if request_json_data:
form_data.update(request_json_data)
if request_form_data:
parsed_form_data = loads_request_json(request_form_data)
# some chart data api requests are form_data
queries = parsed_form_data.get("queries")
if isinstance(queries, list):
form_data.update(queries[0])
else:
form_data.update(parsed_form_data)
# request params can overwrite the body
if request_args_data:
form_data.update(loads_request_json(request_args_data))
# Fallback to using the Flask globals (used for cache warmup and async queries)
if not form_data and hasattr(g, "form_data"):
form_data = getattr(g, "form_data")
# chart data API requests are JSON
json_data = form_data["queries"][0] if "queries" in form_data else {}
form_data.update(json_data)
url_id = request.args.get("r")
if url_id:
saved_url = db.session.query(models.Url).filter_by(id=url_id).first()
if saved_url:
url_str = parse.unquote_plus(
saved_url.url.split("?")[1][10:], encoding="utf-8"
)
url_form_data = loads_request_json(url_str)
# allow form_date in request override saved url
url_form_data.update(form_data)
form_data = url_form_data
if has_request_context(): # type: ignore
url_id = request.args.get("r")
if url_id:
saved_url = db.session.query(models.Url).filter_by(id=url_id).first()
if saved_url:
url_str = parse.unquote_plus(
saved_url.url.split("?")[1][10:], encoding="utf-8"
)
url_form_data = loads_request_json(url_str)
# allow form_date in request override saved url
url_form_data.update(form_data)
form_data = url_form_data
form_data = {k: v for k, v in form_data.items() if k not in REJECTED_FORM_DATA_KEYS}

View File

@ -45,7 +45,8 @@ from tests.integration_tests.test_app import app
class TestAsyncQueries(SupersetTestCase):
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@mock.patch.object(async_query_manager, "update_job")
def test_load_chart_data_into_cache(self, mock_update_job):
@mock.patch.object(async_queries, "set_form_data")
def test_load_chart_data_into_cache(self, mock_set_form_data, mock_update_job):
async_query_manager.init_app(app)
query_context = get_query_context("birth_names")
user = security_manager.find_user("gamma")
@ -63,6 +64,7 @@ class TestAsyncQueries(SupersetTestCase):
load_chart_data_into_cache(job_metadata, query_context)
ensure_user_is_set.assert_called_once_with(user.id)
mock_set_form_data.assert_called_once_with(query_context)
mock_update_job.assert_called_once_with(
job_metadata, "done", result_url=mock.ANY
)
@ -154,7 +156,10 @@ class TestAsyncQueries(SupersetTestCase):
)
@mock.patch.object(async_query_manager, "update_job")
def test_load_explore_json_into_cache_error(self, mock_update_job):
@mock.patch.object(async_queries, "set_form_data")
def test_load_explore_json_into_cache_error(
self, mock_set_form_data, mock_update_job
):
async_query_manager.init_app(app)
user = security_manager.find_user("gamma")
form_data = {}
@ -173,6 +178,7 @@ class TestAsyncQueries(SupersetTestCase):
load_explore_json_into_cache(job_metadata, form_data)
ensure_user_is_set.assert_called_once_with(user.id)
mock_set_form_data.assert_called_once_with(form_data)
errors = ["The dataset associated with this chart no longer exists"]
mock_update_job.assert_called_once_with(job_metadata, "error", errors=errors)