fix(async queries): Remove "force" param on cached data retrieval (#12103)

* Async queries: remove force cache param on data retrieval

* Assert equal query_object cache keys

* Decouple etag_cache from permission checks

* Fix query_context test

* Use marshmallow EnumField for validation
This commit is contained in:
Rob DiCiuccio 2021-01-27 10:16:57 -08:00 committed by GitHub
parent 044d1ae3a3
commit d7cbd53fce
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 74 additions and 39 deletions

View File

@ -30,7 +30,7 @@ combine_as_imports = true
include_trailing_comma = true include_trailing_comma = true
line_length = 88 line_length = 88
known_first_party = superset known_first_party = superset
known_third_party =alembic,apispec,backoff,bleach,cachelib,celery,click,colorama,contextlib2,cron_descriptor,croniter,cryptography,dateutil,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,freezegun,geohash,geopy,holidays,humanize,isodate,jinja2,jwt,markdown,markupsafe,marshmallow,msgpack,numpy,pandas,parameterized,parsedatetime,pathlib2,pgsanity,pkg_resources,polyline,prison,pyarrow,pyhive,pyparsing,pytest,pytz,redis,requests,retry,selenium,setuptools,simplejson,slack,sqlalchemy,sqlalchemy_utils,sqlparse,typing_extensions,werkzeug,wtforms,wtforms_json,yaml known_third_party =alembic,apispec,backoff,bleach,cachelib,celery,click,colorama,contextlib2,cron_descriptor,croniter,cryptography,dateutil,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,freezegun,geohash,geopy,holidays,humanize,isodate,jinja2,jwt,markdown,markupsafe,marshmallow,marshmallow_enum,msgpack,numpy,pandas,parameterized,parsedatetime,pathlib2,pgsanity,pkg_resources,polyline,prison,pyarrow,pyhive,pyparsing,pytest,pytz,redis,requests,retry,selenium,setuptools,simplejson,slack,sqlalchemy,sqlalchemy_utils,sqlparse,typing_extensions,werkzeug,wtforms,wtforms_json,yaml
multi_line_output = 3 multi_line_output = 3
order_by_type = false order_by_type = false

View File

@ -19,11 +19,14 @@ from typing import Any, Dict
from flask_babel import gettext as _ from flask_babel import gettext as _
from marshmallow import EXCLUDE, fields, post_load, Schema, validate from marshmallow import EXCLUDE, fields, post_load, Schema, validate
from marshmallow.validate import Length, Range from marshmallow.validate import Length, Range
from marshmallow_enum import EnumField
from superset.common.query_context import QueryContext from superset.common.query_context import QueryContext
from superset.utils import schema as utils from superset.utils import schema as utils
from superset.utils.core import ( from superset.utils.core import (
AnnotationType, AnnotationType,
ChartDataResultFormat,
ChartDataResultType,
FilterOperator, FilterOperator,
PostProcessingBoxplotWhiskerType, PostProcessingBoxplotWhiskerType,
PostProcessingContributionOrientation, PostProcessingContributionOrientation,
@ -1012,14 +1015,9 @@ class ChartDataQueryContextSchema(Schema):
description="Should the queries be forced to load from the source. " description="Should the queries be forced to load from the source. "
"Default: `false`", "Default: `false`",
) )
result_type = fields.String(
description="Type of results to return", result_type = EnumField(ChartDataResultType, by_value=True)
validate=validate.OneOf(choices=("full", "query", "results", "samples")), result_format = EnumField(ChartDataResultFormat, by_value=True)
)
result_format = fields.String(
description="Format of result payload",
validate=validate.OneOf(choices=("json", "csv")),
)
# pylint: disable=no-self-use,unused-argument # pylint: disable=no-self-use,unused-argument
@post_load @post_load

View File

@ -85,9 +85,8 @@ class QueryContext:
self.cache_values = { self.cache_values = {
"datasource": datasource, "datasource": datasource,
"queries": queries, "queries": queries,
"force": force, "result_type": self.result_type,
"result_type": result_type, "result_format": self.result_format,
"result_format": result_format,
} }
def get_query_result(self, query_object: QueryObject) -> Dict[str, Any]: def get_query_result(self, query_object: QueryObject) -> Dict[str, Any]:

View File

@ -125,9 +125,7 @@ def memoized_func(
def etag_cache( def etag_cache(
check_perms: Callable[..., Any], cache: Cache = cache_manager.cache, max_age: Optional[Union[int, float]] = None,
cache: Cache = cache_manager.cache,
max_age: Optional[Union[int, float]] = None,
) -> Callable[..., Any]: ) -> Callable[..., Any]:
""" """
A decorator for caching views and handling etag conditional requests. A decorator for caching views and handling etag conditional requests.
@ -147,9 +145,6 @@ def etag_cache(
def decorator(f: Callable[..., Any]) -> Callable[..., Any]: def decorator(f: Callable[..., Any]) -> Callable[..., Any]:
@wraps(f) @wraps(f)
def wrapper(*args: Any, **kwargs: Any) -> ETagResponseMixin: def wrapper(*args: Any, **kwargs: Any) -> ETagResponseMixin:
# check if the user can access the resource
check_perms(*args, **kwargs)
# for POST requests we can't set cache headers, use the response # for POST requests we can't set cache headers, use the response
# cache nor use conditional requests; this will still use the # cache nor use conditional requests; this will still use the
# dataframe cache in `superset/viz.py`, though. # dataframe cache in `superset/viz.py`, though.

View File

@ -127,6 +127,7 @@ from superset.views.utils import (
bootstrap_user_data, bootstrap_user_data,
check_datasource_perms, check_datasource_perms,
check_explore_cache_perms, check_explore_cache_perms,
check_resource_permissions,
check_slice_perms, check_slice_perms,
get_cta_schema_name, get_cta_schema_name,
get_dashboard_extra_filters, get_dashboard_extra_filters,
@ -458,7 +459,8 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
@api @api
@has_access_api @has_access_api
@expose("/slice_json/<int:slice_id>") @expose("/slice_json/<int:slice_id>")
@etag_cache(check_perms=check_slice_perms) @etag_cache()
@check_resource_permissions(check_slice_perms)
def slice_json(self, slice_id: int) -> FlaskResponse: def slice_json(self, slice_id: int) -> FlaskResponse:
form_data, slc = get_form_data(slice_id, use_slice_data=True) form_data, slc = get_form_data(slice_id, use_slice_data=True)
if not slc: if not slc:
@ -510,7 +512,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
@handle_api_exception @handle_api_exception
@permission_name("explore_json") @permission_name("explore_json")
@expose("/explore_json/data/<cache_key>", methods=["GET"]) @expose("/explore_json/data/<cache_key>", methods=["GET"])
@etag_cache(check_perms=check_explore_cache_perms) @check_resource_permissions(check_explore_cache_perms)
def explore_json_data(self, cache_key: str) -> FlaskResponse: def explore_json_data(self, cache_key: str) -> FlaskResponse:
"""Serves cached result data for async explore_json calls """Serves cached result data for async explore_json calls
@ -554,7 +556,8 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
methods=EXPLORE_JSON_METHODS, methods=EXPLORE_JSON_METHODS,
) )
@expose("/explore_json/", methods=EXPLORE_JSON_METHODS) @expose("/explore_json/", methods=EXPLORE_JSON_METHODS)
@etag_cache(check_perms=check_datasource_perms) @etag_cache()
@check_resource_permissions(check_datasource_perms)
def explore_json( def explore_json(
self, datasource_type: Optional[str] = None, datasource_id: Optional[int] = None self, datasource_type: Optional[str] = None, datasource_id: Optional[int] = None
) -> FlaskResponse: ) -> FlaskResponse:

View File

@ -17,6 +17,7 @@
import logging import logging
from collections import defaultdict from collections import defaultdict
from datetime import date from datetime import date
from functools import wraps
from typing import Any, Callable, DefaultDict, Dict, List, Optional, Set, Tuple, Union from typing import Any, Callable, DefaultDict, Dict, List, Optional, Set, Tuple, Union
from urllib import parse from urllib import parse
@ -437,6 +438,23 @@ def is_owner(obj: Union[Dashboard, Slice], user: User) -> bool:
return obj and user in obj.owners return obj and user in obj.owners
def check_resource_permissions(check_perms: Callable[..., Any],) -> Callable[..., Any]:
    """
    A decorator factory for checking permissions on a request using the
    passed-in function before invoking the wrapped view.

    :param check_perms: callable invoked with the view's ``*args``/``**kwargs``;
        expected to raise on failure, aborting the request before the view runs.
    :returns: a decorator that wraps a view function with the permission check.
    """

    def decorator(f: Callable[..., Any]) -> Callable[..., Any]:
        @wraps(f)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            # check if the user can access the resource
            check_perms(*args, **kwargs)
            # Fix: originally annotated `-> None`, yet the wrapper returns the
            # wrapped view's response — the annotation must be `Any`.
            return f(*args, **kwargs)

        return wrapper

    return decorator
def check_explore_cache_perms(_self: Any, cache_key: str) -> None: def check_explore_cache_perms(_self: Any, cache_key: str) -> None:
""" """
Loads async explore_json request data from cache and performs access check Loads async explore_json request data from cache and performs access check

View File

@ -143,13 +143,6 @@ class BaseViz:
self._force_cached = force_cached self._force_cached = force_cached
self.from_dttm: Optional[datetime] = None self.from_dttm: Optional[datetime] = None
self.to_dttm: Optional[datetime] = None self.to_dttm: Optional[datetime] = None
# Keeping track of whether some data came from cache
# this is useful to trigger the <CachedLabel /> when
# in the cases where visualization have many queries
# (FilterBox for instance)
self._any_cache_key: Optional[str] = None
self._any_cached_dttm: Optional[str] = None
self._extra_chart_data: List[Tuple[str, pd.DataFrame]] = [] self._extra_chart_data: List[Tuple[str, pd.DataFrame]] = []
self.process_metrics() self.process_metrics()
@ -496,6 +489,7 @@ class BaseViz:
if not query_obj: if not query_obj:
query_obj = self.query_obj() query_obj = self.query_obj()
cache_key = self.cache_key(query_obj, **kwargs) if query_obj else None cache_key = self.cache_key(query_obj, **kwargs) if query_obj else None
cache_value = None
logger.info("Cache key: {}".format(cache_key)) logger.info("Cache key: {}".format(cache_key))
is_loaded = False is_loaded = False
stacktrace = None stacktrace = None
@ -507,8 +501,6 @@ class BaseViz:
try: try:
df = cache_value["df"] df = cache_value["df"]
self.query = cache_value["query"] self.query = cache_value["query"]
self._any_cached_dttm = cache_value["dttm"]
self._any_cache_key = cache_key
self.status = utils.QueryStatus.SUCCESS self.status = utils.QueryStatus.SUCCESS
is_loaded = True is_loaded = True
stats_logger.incr("loaded_from_cache") stats_logger.incr("loaded_from_cache")
@ -583,13 +575,13 @@ class BaseViz:
self.datasource.uid, self.datasource.uid,
) )
return { return {
"cache_key": self._any_cache_key, "cache_key": cache_key,
"cached_dttm": self._any_cached_dttm, "cached_dttm": cache_value["dttm"] if cache_value is not None else None,
"cache_timeout": self.cache_timeout, "cache_timeout": self.cache_timeout,
"df": df, "df": df,
"errors": self.errors, "errors": self.errors,
"form_data": self.form_data, "form_data": self.form_data,
"is_cached": self._any_cache_key is not None, "is_cached": cache_value is not None,
"query": self.query, "query": self.query,
"from_dttm": self.from_dttm, "from_dttm": self.from_dttm,
"to_dttm": self.to_dttm, "to_dttm": self.to_dttm,

View File

@ -19,6 +19,8 @@ import pytest
from superset import db from superset import db
from superset.charts.schemas import ChartDataQueryContextSchema from superset.charts.schemas import ChartDataQueryContextSchema
from superset.connectors.connector_registry import ConnectorRegistry from superset.connectors.connector_registry import ConnectorRegistry
from superset.extensions import cache_manager
from superset.models.cache import CacheKey
from superset.utils.core import ( from superset.utils.core import (
AdhocMetricExpressionType, AdhocMetricExpressionType,
ChartDataResultFormat, ChartDataResultFormat,
@ -68,11 +70,39 @@ class TestQueryContext(SupersetTestCase):
self.assertEqual(post_proc["operation"], payload_post_proc["operation"]) self.assertEqual(post_proc["operation"], payload_post_proc["operation"])
self.assertEqual(post_proc["options"], payload_post_proc["options"]) self.assertEqual(post_proc["options"], payload_post_proc["options"])
def test_cache_key_changes_when_datasource_is_updated(self): def test_cache(self):
table_name = "birth_names"
table = self.get_table_by_name(table_name)
payload = get_query_context(table.name, table.id)
payload["force"] = True
query_context = ChartDataQueryContextSchema().load(payload)
query_object = query_context.queries[0]
query_cache_key = query_context.query_cache_key(query_object)
response = query_context.get_payload(cache_query_context=True)
cache_key = response["cache_key"]
assert cache_key is not None
cached = cache_manager.cache.get(cache_key)
assert cached is not None
rehydrated_qc = ChartDataQueryContextSchema().load(cached["data"])
rehydrated_qo = rehydrated_qc.queries[0]
rehydrated_query_cache_key = rehydrated_qc.query_cache_key(rehydrated_qo)
self.assertEqual(rehydrated_qc.datasource, query_context.datasource)
self.assertEqual(len(rehydrated_qc.queries), 1)
self.assertEqual(query_cache_key, rehydrated_query_cache_key)
self.assertEqual(rehydrated_qc.result_type, query_context.result_type)
self.assertEqual(rehydrated_qc.result_format, query_context.result_format)
self.assertFalse(rehydrated_qc.force)
def test_query_cache_key_changes_when_datasource_is_updated(self):
self.login(username="admin") self.login(username="admin")
payload = get_query_context("birth_names") payload = get_query_context("birth_names")
# construct baseline cache_key # construct baseline query_cache_key
query_context = ChartDataQueryContextSchema().load(payload) query_context = ChartDataQueryContextSchema().load(payload)
query_object = query_context.queries[0] query_object = query_context.queries[0]
cache_key_original = query_context.query_cache_key(query_object) cache_key_original = query_context.query_cache_key(query_object)
@ -89,7 +119,7 @@ class TestQueryContext(SupersetTestCase):
datasource.description = description_original datasource.description = description_original
db.session.commit() db.session.commit()
# create new QueryContext with unchanged attributes and extract new cache_key # create new QueryContext with unchanged attributes, extract new query_cache_key
query_context = ChartDataQueryContextSchema().load(payload) query_context = ChartDataQueryContextSchema().load(payload)
query_object = query_context.queries[0] query_object = query_context.queries[0]
cache_key_new = query_context.query_cache_key(query_object) cache_key_new = query_context.query_cache_key(query_object)
@ -97,16 +127,16 @@ class TestQueryContext(SupersetTestCase):
# the new cache_key should be different due to updated datasource # the new cache_key should be different due to updated datasource
self.assertNotEqual(cache_key_original, cache_key_new) self.assertNotEqual(cache_key_original, cache_key_new)
def test_cache_key_changes_when_post_processing_is_updated(self): def test_query_cache_key_changes_when_post_processing_is_updated(self):
self.login(username="admin") self.login(username="admin")
payload = get_query_context("birth_names", add_postprocessing_operations=True) payload = get_query_context("birth_names", add_postprocessing_operations=True)
# construct baseline cache_key from query_context with post processing operation # construct baseline query_cache_key from query_context with post processing operation
query_context = ChartDataQueryContextSchema().load(payload) query_context = ChartDataQueryContextSchema().load(payload)
query_object = query_context.queries[0] query_object = query_context.queries[0]
cache_key_original = query_context.query_cache_key(query_object) cache_key_original = query_context.query_cache_key(query_object)
# ensure added None post_processing operation doesn't change cache_key # ensure added None post_processing operation doesn't change query_cache_key
payload["queries"][0]["post_processing"].append(None) payload["queries"][0]["post_processing"].append(None)
query_context = ChartDataQueryContextSchema().load(payload) query_context = ChartDataQueryContextSchema().load(payload)
query_object = query_context.queries[0] query_object = query_context.queries[0]