fix(async queries): Remove "force" param on cached data retrieval (#12103)

* Async queries: remove force cache param on data retrieval

* Assert equal query_object cache keys

* Decouple etag_cache from permission checks

* Fix query_context test

* Use marshmallow EnumField for validation
This commit is contained in:
Rob DiCiuccio 2021-01-27 10:16:57 -08:00 committed by GitHub
parent 044d1ae3a3
commit d7cbd53fce
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 74 additions and 39 deletions

View File

@@ -30,7 +30,7 @@ combine_as_imports = true
include_trailing_comma = true
line_length = 88
known_first_party = superset
known_third_party =alembic,apispec,backoff,bleach,cachelib,celery,click,colorama,contextlib2,cron_descriptor,croniter,cryptography,dateutil,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,freezegun,geohash,geopy,holidays,humanize,isodate,jinja2,jwt,markdown,markupsafe,marshmallow,msgpack,numpy,pandas,parameterized,parsedatetime,pathlib2,pgsanity,pkg_resources,polyline,prison,pyarrow,pyhive,pyparsing,pytest,pytz,redis,requests,retry,selenium,setuptools,simplejson,slack,sqlalchemy,sqlalchemy_utils,sqlparse,typing_extensions,werkzeug,wtforms,wtforms_json,yaml
known_third_party =alembic,apispec,backoff,bleach,cachelib,celery,click,colorama,contextlib2,cron_descriptor,croniter,cryptography,dateutil,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,freezegun,geohash,geopy,holidays,humanize,isodate,jinja2,jwt,markdown,markupsafe,marshmallow,marshmallow_enum,msgpack,numpy,pandas,parameterized,parsedatetime,pathlib2,pgsanity,pkg_resources,polyline,prison,pyarrow,pyhive,pyparsing,pytest,pytz,redis,requests,retry,selenium,setuptools,simplejson,slack,sqlalchemy,sqlalchemy_utils,sqlparse,typing_extensions,werkzeug,wtforms,wtforms_json,yaml
multi_line_output = 3
order_by_type = false

View File

@@ -19,11 +19,14 @@ from typing import Any, Dict
from flask_babel import gettext as _
from marshmallow import EXCLUDE, fields, post_load, Schema, validate
from marshmallow.validate import Length, Range
from marshmallow_enum import EnumField
from superset.common.query_context import QueryContext
from superset.utils import schema as utils
from superset.utils.core import (
AnnotationType,
ChartDataResultFormat,
ChartDataResultType,
FilterOperator,
PostProcessingBoxplotWhiskerType,
PostProcessingContributionOrientation,
@@ -1012,14 +1015,9 @@ class ChartDataQueryContextSchema(Schema):
description="Should the queries be forced to load from the source. "
"Default: `false`",
)
result_type = fields.String(
description="Type of results to return",
validate=validate.OneOf(choices=("full", "query", "results", "samples")),
)
result_format = fields.String(
description="Format of result payload",
validate=validate.OneOf(choices=("json", "csv")),
)
result_type = EnumField(ChartDataResultType, by_value=True)
result_format = EnumField(ChartDataResultFormat, by_value=True)
# pylint: disable=no-self-use,unused-argument
@post_load

View File

@@ -85,9 +85,8 @@ class QueryContext:
self.cache_values = {
"datasource": datasource,
"queries": queries,
"force": force,
"result_type": result_type,
"result_format": result_format,
"result_type": self.result_type,
"result_format": self.result_format,
}
def get_query_result(self, query_object: QueryObject) -> Dict[str, Any]:

View File

@@ -125,9 +125,7 @@ def memoized_func(
def etag_cache(
check_perms: Callable[..., Any],
cache: Cache = cache_manager.cache,
max_age: Optional[Union[int, float]] = None,
cache: Cache = cache_manager.cache, max_age: Optional[Union[int, float]] = None,
) -> Callable[..., Any]:
"""
A decorator for caching views and handling etag conditional requests.
@@ -147,9 +145,6 @@ def etag_cache(
def decorator(f: Callable[..., Any]) -> Callable[..., Any]:
@wraps(f)
def wrapper(*args: Any, **kwargs: Any) -> ETagResponseMixin:
# check if the user can access the resource
check_perms(*args, **kwargs)
# for POST requests we can't set cache headers, use the response
# cache nor use conditional requests; this will still use the
# dataframe cache in `superset/viz.py`, though.

View File

@@ -127,6 +127,7 @@ from superset.views.utils import (
bootstrap_user_data,
check_datasource_perms,
check_explore_cache_perms,
check_resource_permissions,
check_slice_perms,
get_cta_schema_name,
get_dashboard_extra_filters,
@@ -458,7 +459,8 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
@api
@has_access_api
@expose("/slice_json/<int:slice_id>")
@etag_cache(check_perms=check_slice_perms)
@etag_cache()
@check_resource_permissions(check_slice_perms)
def slice_json(self, slice_id: int) -> FlaskResponse:
form_data, slc = get_form_data(slice_id, use_slice_data=True)
if not slc:
@@ -510,7 +512,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
@handle_api_exception
@permission_name("explore_json")
@expose("/explore_json/data/<cache_key>", methods=["GET"])
@etag_cache(check_perms=check_explore_cache_perms)
@check_resource_permissions(check_explore_cache_perms)
def explore_json_data(self, cache_key: str) -> FlaskResponse:
"""Serves cached result data for async explore_json calls
@@ -554,7 +556,8 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
methods=EXPLORE_JSON_METHODS,
)
@expose("/explore_json/", methods=EXPLORE_JSON_METHODS)
@etag_cache(check_perms=check_datasource_perms)
@etag_cache()
@check_resource_permissions(check_datasource_perms)
def explore_json(
self, datasource_type: Optional[str] = None, datasource_id: Optional[int] = None
) -> FlaskResponse:

View File

@@ -17,6 +17,7 @@
import logging
from collections import defaultdict
from datetime import date
from functools import wraps
from typing import Any, Callable, DefaultDict, Dict, List, Optional, Set, Tuple, Union
from urllib import parse
@@ -437,6 +438,23 @@ def is_owner(obj: Union[Dashboard, Slice], user: User) -> bool:
return obj and user in obj.owners
def check_resource_permissions(check_perms: Callable[..., Any],) -> Callable[..., Any]:
    """
    A decorator for checking permissions on a request using the passed-in function.

    :param check_perms: callable invoked with the wrapped view's positional and
        keyword arguments; expected to raise if the current user may not access
        the requested resource (it is called purely for its side effect — its
        return value is ignored).
    :returns: a decorator that runs ``check_perms`` before the wrapped view and
        then returns the view's response unchanged.
    """

    def decorator(f: Callable[..., Any]) -> Callable[..., Any]:
        @wraps(f)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            # check if the user can access the resource
            check_perms(*args, **kwargs)
            # NOTE: the original annotated this wrapper `-> None` even though it
            # forwards the view's response; `Any` reflects the actual behavior.
            return f(*args, **kwargs)

        return wrapper

    return decorator
def check_explore_cache_perms(_self: Any, cache_key: str) -> None:
"""
Loads async explore_json request data from cache and performs access check

View File

@@ -143,13 +143,6 @@ class BaseViz:
self._force_cached = force_cached
self.from_dttm: Optional[datetime] = None
self.to_dttm: Optional[datetime] = None
# Keeping track of whether some data came from cache
# this is useful to trigger the <CachedLabel /> when
# in the cases where visualization have many queries
# (FilterBox for instance)
self._any_cache_key: Optional[str] = None
self._any_cached_dttm: Optional[str] = None
self._extra_chart_data: List[Tuple[str, pd.DataFrame]] = []
self.process_metrics()
@@ -496,6 +489,7 @@ class BaseViz:
if not query_obj:
query_obj = self.query_obj()
cache_key = self.cache_key(query_obj, **kwargs) if query_obj else None
cache_value = None
logger.info("Cache key: {}".format(cache_key))
is_loaded = False
stacktrace = None
@@ -507,8 +501,6 @@
try:
df = cache_value["df"]
self.query = cache_value["query"]
self._any_cached_dttm = cache_value["dttm"]
self._any_cache_key = cache_key
self.status = utils.QueryStatus.SUCCESS
is_loaded = True
stats_logger.incr("loaded_from_cache")
@@ -583,13 +575,13 @@
self.datasource.uid,
)
return {
"cache_key": self._any_cache_key,
"cached_dttm": self._any_cached_dttm,
"cache_key": cache_key,
"cached_dttm": cache_value["dttm"] if cache_value is not None else None,
"cache_timeout": self.cache_timeout,
"df": df,
"errors": self.errors,
"form_data": self.form_data,
"is_cached": self._any_cache_key is not None,
"is_cached": cache_value is not None,
"query": self.query,
"from_dttm": self.from_dttm,
"to_dttm": self.to_dttm,

View File

@@ -19,6 +19,8 @@ import pytest
from superset import db
from superset.charts.schemas import ChartDataQueryContextSchema
from superset.connectors.connector_registry import ConnectorRegistry
from superset.extensions import cache_manager
from superset.models.cache import CacheKey
from superset.utils.core import (
AdhocMetricExpressionType,
ChartDataResultFormat,
@@ -68,11 +70,39 @@ class TestQueryContext(SupersetTestCase):
self.assertEqual(post_proc["operation"], payload_post_proc["operation"])
self.assertEqual(post_proc["options"], payload_post_proc["options"])
def test_cache_key_changes_when_datasource_is_updated(self):
def test_cache(self):
    """Round-trip a QueryContext through the query-context cache and check
    that the rehydrated copy matches the original — except for `force`,
    which must not be persisted into the cached context."""
    table = self.get_table_by_name("birth_names")
    payload = get_query_context(table.name, table.id)
    payload["force"] = True

    original_qc = ChartDataQueryContextSchema().load(payload)
    original_key = original_qc.query_cache_key(original_qc.queries[0])

    # Generating the payload with cache_query_context=True should store the
    # query context and report the cache key it was stored under.
    cache_key = original_qc.get_payload(cache_query_context=True)["cache_key"]
    assert cache_key is not None

    cache_entry = cache_manager.cache.get(cache_key)
    assert cache_entry is not None

    # Rehydrate a QueryContext from the cached data and compare it with the
    # original context it was built from.
    restored_qc = ChartDataQueryContextSchema().load(cache_entry["data"])
    restored_key = restored_qc.query_cache_key(restored_qc.queries[0])

    self.assertEqual(restored_qc.datasource, original_qc.datasource)
    self.assertEqual(len(restored_qc.queries), 1)
    self.assertEqual(original_key, restored_key)
    self.assertEqual(restored_qc.result_type, original_qc.result_type)
    self.assertEqual(restored_qc.result_format, original_qc.result_format)
    # `force` was True on the original payload but is deliberately dropped
    # from the cached context.
    self.assertFalse(restored_qc.force)
def test_query_cache_key_changes_when_datasource_is_updated(self):
self.login(username="admin")
payload = get_query_context("birth_names")
# construct baseline cache_key
# construct baseline query_cache_key
query_context = ChartDataQueryContextSchema().load(payload)
query_object = query_context.queries[0]
cache_key_original = query_context.query_cache_key(query_object)
@@ -89,7 +119,7 @@ class TestQueryContext(SupersetTestCase):
datasource.description = description_original
db.session.commit()
# create new QueryContext with unchanged attributes and extract new cache_key
# create new QueryContext with unchanged attributes, extract new query_cache_key
query_context = ChartDataQueryContextSchema().load(payload)
query_object = query_context.queries[0]
cache_key_new = query_context.query_cache_key(query_object)
@@ -97,16 +127,16 @@ class TestQueryContext(SupersetTestCase):
# the new cache_key should be different due to updated datasource
self.assertNotEqual(cache_key_original, cache_key_new)
def test_cache_key_changes_when_post_processing_is_updated(self):
def test_query_cache_key_changes_when_post_processing_is_updated(self):
self.login(username="admin")
payload = get_query_context("birth_names", add_postprocessing_operations=True)
# construct baseline cache_key from query_context with post processing operation
# construct baseline query_cache_key from query_context with post processing operation
query_context = ChartDataQueryContextSchema().load(payload)
query_object = query_context.queries[0]
cache_key_original = query_context.query_cache_key(query_object)
# ensure added None post_processing operation doesn't change cache_key
# ensure added None post_processing operation doesn't change query_cache_key
payload["queries"][0]["post_processing"].append(None)
query_context = ChartDataQueryContextSchema().load(payload)
query_object = query_context.queries[0]