From d7cbd53fce2c6ec2301d99bfda0b1969e5ec36e5 Mon Sep 17 00:00:00 2001
From: Rob DiCiuccio
Date: Wed, 27 Jan 2021 10:16:57 -0800
Subject: [PATCH] fix(async queries): Remove "force" param on cached data
 retrieval (#12103)

* Async queries: remove force cache param on data retrieval

* Assert equal query_object cache keys

* Decouple etag_cache from permission checks

* Fix query_context test

* Use marshmallow EnumField for validation
---
 setup.cfg                        |  2 +-
 superset/charts/schemas.py       | 14 +++++------
 superset/common/query_context.py |  5 ++--
 superset/utils/cache.py          |  7 +-----
 superset/views/core.py           |  9 ++++---
 superset/views/utils.py          | 18 ++++++++++++++
 superset/viz.py                  | 16 +++---------
 tests/query_context_tests.py     | 42 +++++++++++++++++++++++++++-----
 8 files changed, 74 insertions(+), 39 deletions(-)

diff --git a/setup.cfg b/setup.cfg
index c4d4140cd..9dd35f5fe 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -30,7 +30,7 @@ combine_as_imports = true
 include_trailing_comma = true
 line_length = 88
 known_first_party = superset
-known_third_party =alembic,apispec,backoff,bleach,cachelib,celery,click,colorama,contextlib2,cron_descriptor,croniter,cryptography,dateutil,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,freezegun,geohash,geopy,holidays,humanize,isodate,jinja2,jwt,markdown,markupsafe,marshmallow,msgpack,numpy,pandas,parameterized,parsedatetime,pathlib2,pgsanity,pkg_resources,polyline,prison,pyarrow,pyhive,pyparsing,pytest,pytz,redis,requests,retry,selenium,setuptools,simplejson,slack,sqlalchemy,sqlalchemy_utils,sqlparse,typing_extensions,werkzeug,wtforms,wtforms_json,yaml
+known_third_party =alembic,apispec,backoff,bleach,cachelib,celery,click,colorama,contextlib2,cron_descriptor,croniter,cryptography,dateutil,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,freezegun,geohash,geopy,holidays,humanize,isodate,jinja2,jwt,markdown,markupsafe,marshmallow,marshmallow_enum,msgpack,numpy,pandas,parameterized,parsedatetime,pathlib2,pgsanity,pkg_resources,polyline,prison,pyarrow,pyhive,pyparsing,pytest,pytz,redis,requests,retry,selenium,setuptools,simplejson,slack,sqlalchemy,sqlalchemy_utils,sqlparse,typing_extensions,werkzeug,wtforms,wtforms_json,yaml
 multi_line_output = 3
 order_by_type = false
 
diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py
index f85ad00a4..d62707e0b 100644
--- a/superset/charts/schemas.py
+++ b/superset/charts/schemas.py
@@ -19,11 +19,14 @@ from typing import Any, Dict
 from flask_babel import gettext as _
 from marshmallow import EXCLUDE, fields, post_load, Schema, validate
 from marshmallow.validate import Length, Range
+from marshmallow_enum import EnumField
 
 from superset.common.query_context import QueryContext
 from superset.utils import schema as utils
 from superset.utils.core import (
     AnnotationType,
+    ChartDataResultFormat,
+    ChartDataResultType,
     FilterOperator,
     PostProcessingBoxplotWhiskerType,
     PostProcessingContributionOrientation,
@@ -1012,14 +1015,9 @@ class ChartDataQueryContextSchema(Schema):
         description="Should the queries be forced to load from the source. "
         "Default: `false`",
     )
-    result_type = fields.String(
-        description="Type of results to return",
-        validate=validate.OneOf(choices=("full", "query", "results", "samples")),
-    )
-    result_format = fields.String(
-        description="Format of result payload",
-        validate=validate.OneOf(choices=("json", "csv")),
-    )
+
+    result_type = EnumField(ChartDataResultType, by_value=True)
+    result_format = EnumField(ChartDataResultFormat, by_value=True)
 
     # pylint: disable=no-self-use,unused-argument
     @post_load
diff --git a/superset/common/query_context.py b/superset/common/query_context.py
index 041100724..3c2081364 100644
--- a/superset/common/query_context.py
+++ b/superset/common/query_context.py
@@ -85,9 +85,8 @@ class QueryContext:
         self.cache_values = {
             "datasource": datasource,
             "queries": queries,
-            "force": force,
-            "result_type": result_type,
-            "result_format": result_format,
+            "result_type": self.result_type,
+            "result_format": self.result_format,
         }
 
     def get_query_result(self, query_object: QueryObject) -> Dict[str, Any]:
diff --git a/superset/utils/cache.py b/superset/utils/cache.py
index 729c31686..66da4688b 100644
--- a/superset/utils/cache.py
+++ b/superset/utils/cache.py
@@ -125,9 +125,7 @@ def memoized_func(
 
 
 def etag_cache(
-    check_perms: Callable[..., Any],
-    cache: Cache = cache_manager.cache,
-    max_age: Optional[Union[int, float]] = None,
+    cache: Cache = cache_manager.cache, max_age: Optional[Union[int, float]] = None,
 ) -> Callable[..., Any]:
     """
     A decorator for caching views and handling etag conditional requests.
@@ -147,9 +145,6 @@ def etag_cache(
     def decorator(f: Callable[..., Any]) -> Callable[..., Any]:
         @wraps(f)
         def wrapper(*args: Any, **kwargs: Any) -> ETagResponseMixin:
-            # check if the user can access the resource
-            check_perms(*args, **kwargs)
-
             # for POST requests we can't set cache headers, use the response
             # cache nor use conditional requests; this will still use the
             # dataframe cache in `superset/viz.py`, though.
diff --git a/superset/views/core.py b/superset/views/core.py
index 087ef422b..47df3e865 100755
--- a/superset/views/core.py
+++ b/superset/views/core.py
@@ -127,6 +127,7 @@ from superset.views.utils import (
     bootstrap_user_data,
     check_datasource_perms,
     check_explore_cache_perms,
+    check_resource_permissions,
     check_slice_perms,
     get_cta_schema_name,
     get_dashboard_extra_filters,
@@ -458,7 +459,8 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
     @api
     @has_access_api
     @expose("/slice_json/<int:slice_id>/")
-    @etag_cache(check_perms=check_slice_perms)
+    @etag_cache()
+    @check_resource_permissions(check_slice_perms)
     def slice_json(self, slice_id: int) -> FlaskResponse:
         form_data, slc = get_form_data(slice_id, use_slice_data=True)
         if not slc:
@@ -510,7 +512,7 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
     @handle_api_exception
     @permission_name("explore_json")
     @expose("/explore_json/data/<cache_key>", methods=["GET"])
-    @etag_cache(check_perms=check_explore_cache_perms)
+    @check_resource_permissions(check_explore_cache_perms)
     def explore_json_data(self, cache_key: str) -> FlaskResponse:
         """Serves cached result data for async explore_json calls
 
@@ -554,7 +556,8 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
         methods=EXPLORE_JSON_METHODS,
     )
     @expose("/explore_json/", methods=EXPLORE_JSON_METHODS)
-    @etag_cache(check_perms=check_datasource_perms)
+    @etag_cache()
+    @check_resource_permissions(check_datasource_perms)
     def explore_json(
         self, datasource_type: Optional[str] = None, datasource_id: Optional[int] = None
     ) -> FlaskResponse:
diff --git a/superset/views/utils.py b/superset/views/utils.py
index f37d055d8..5f9ec10e8 100644
--- a/superset/views/utils.py
+++ b/superset/views/utils.py
@@ -17,6 +17,7 @@
 import logging
 from collections import defaultdict
 from datetime import date
+from functools import wraps
 from typing import Any, Callable, DefaultDict, Dict, List, Optional, Set, Tuple, Union
 from urllib import parse
 
@@ -437,6 +438,23 @@ def is_owner(obj: Union[Dashboard, Slice], user: User) -> bool:
     return obj and user in obj.owners
 
 
+def check_resource_permissions(check_perms: Callable[..., Any],) -> Callable[..., Any]:
+    """
+    A decorator for checking permissions on a request using the passed-in function.
+    """
+
+    def decorator(f: Callable[..., Any]) -> Callable[..., Any]:
+        @wraps(f)
+        def wrapper(*args: Any, **kwargs: Any) -> None:
+            # check if the user can access the resource
+            check_perms(*args, **kwargs)
+            return f(*args, **kwargs)
+
+        return wrapper
+
+    return decorator
+
+
 def check_explore_cache_perms(_self: Any, cache_key: str) -> None:
     """
     Loads async explore_json request data from cache and performs access check
diff --git a/superset/viz.py b/superset/viz.py
index fa6588733..4250022c6 100644
--- a/superset/viz.py
+++ b/superset/viz.py
@@ -143,13 +143,6 @@ class BaseViz:
         self._force_cached = force_cached
         self.from_dttm: Optional[datetime] = None
         self.to_dttm: Optional[datetime] = None
-
-        # Keeping track of whether some data came from cache
-        # this is useful to trigger the <CachedLabel /> when
-        # in the cases where visualization have many queries
-        # (FilterBox for instance)
-        self._any_cache_key: Optional[str] = None
-        self._any_cached_dttm: Optional[str] = None
         self._extra_chart_data: List[Tuple[str, pd.DataFrame]] = []
 
         self.process_metrics()
@@ -496,6 +489,7 @@ class BaseViz:
         if not query_obj:
             query_obj = self.query_obj()
         cache_key = self.cache_key(query_obj, **kwargs) if query_obj else None
+        cache_value = None
         logger.info("Cache key: {}".format(cache_key))
         is_loaded = False
         stacktrace = None
@@ -507,8 +501,6 @@ class BaseViz:
                 try:
                     df = cache_value["df"]
                     self.query = cache_value["query"]
-                    self._any_cached_dttm = cache_value["dttm"]
-                    self._any_cache_key = cache_key
                     self.status = utils.QueryStatus.SUCCESS
                     is_loaded = True
                     stats_logger.incr("loaded_from_cache")
@@ -583,13 +575,13 @@ class BaseViz:
                 self.datasource.uid,
             )
         return {
-            "cache_key": self._any_cache_key,
-            "cached_dttm": self._any_cached_dttm,
+            "cache_key": cache_key,
+            "cached_dttm": cache_value["dttm"] if cache_value is not None else None,
             "cache_timeout": self.cache_timeout,
             "df": df,
             "errors": self.errors,
             "form_data": self.form_data,
-            "is_cached": self._any_cache_key is not None,
+            "is_cached": cache_value is not None,
             "query": self.query,
             "from_dttm": self.from_dttm,
             "to_dttm": self.to_dttm,
diff --git a/tests/query_context_tests.py b/tests/query_context_tests.py
index 220190047..99bc9421e 100644
--- a/tests/query_context_tests.py
+++ b/tests/query_context_tests.py
@@ -19,6 +19,8 @@ import pytest
 from superset import db
 from superset.charts.schemas import ChartDataQueryContextSchema
 from superset.connectors.connector_registry import ConnectorRegistry
+from superset.extensions import cache_manager
+from superset.models.cache import CacheKey
 from superset.utils.core import (
     AdhocMetricExpressionType,
     ChartDataResultFormat,
@@ -68,11 +70,39 @@ class TestQueryContext(SupersetTestCase):
             self.assertEqual(post_proc["operation"], payload_post_proc["operation"])
             self.assertEqual(post_proc["options"], payload_post_proc["options"])
 
-    def test_cache_key_changes_when_datasource_is_updated(self):
+    def test_cache(self):
+        table_name = "birth_names"
+        table = self.get_table_by_name(table_name)
+        payload = get_query_context(table.name, table.id)
+        payload["force"] = True
+
+        query_context = ChartDataQueryContextSchema().load(payload)
+        query_object = query_context.queries[0]
+        query_cache_key = query_context.query_cache_key(query_object)
+
+        response = query_context.get_payload(cache_query_context=True)
+        cache_key = response["cache_key"]
+        assert cache_key is not None
+
+        cached = cache_manager.cache.get(cache_key)
+        assert cached is not None
+
+        rehydrated_qc = ChartDataQueryContextSchema().load(cached["data"])
+        rehydrated_qo = rehydrated_qc.queries[0]
+        rehydrated_query_cache_key = rehydrated_qc.query_cache_key(rehydrated_qo)
+
+        self.assertEqual(rehydrated_qc.datasource, query_context.datasource)
+        self.assertEqual(len(rehydrated_qc.queries), 1)
+        self.assertEqual(query_cache_key, rehydrated_query_cache_key)
+        self.assertEqual(rehydrated_qc.result_type, query_context.result_type)
+        self.assertEqual(rehydrated_qc.result_format, query_context.result_format)
+        self.assertFalse(rehydrated_qc.force)
+
+    def test_query_cache_key_changes_when_datasource_is_updated(self):
         self.login(username="admin")
         payload = get_query_context("birth_names")
 
-        # construct baseline cache_key
+        # construct baseline query_cache_key
         query_context = ChartDataQueryContextSchema().load(payload)
         query_object = query_context.queries[0]
         cache_key_original = query_context.query_cache_key(query_object)
@@ -89,7 +119,7 @@ class TestQueryContext(SupersetTestCase):
         datasource.description = description_original
         db.session.commit()
 
-        # create new QueryContext with unchanged attributes and extract new cache_key
+        # create new QueryContext with unchanged attributes, extract new query_cache_key
         query_context = ChartDataQueryContextSchema().load(payload)
         query_object = query_context.queries[0]
         cache_key_new = query_context.query_cache_key(query_object)
@@ -97,16 +127,16 @@ class TestQueryContext(SupersetTestCase):
         # the new cache_key should be different due to updated datasource
         self.assertNotEqual(cache_key_original, cache_key_new)
 
-    def test_cache_key_changes_when_post_processing_is_updated(self):
+    def test_query_cache_key_changes_when_post_processing_is_updated(self):
         self.login(username="admin")
         payload = get_query_context("birth_names", add_postprocessing_operations=True)
 
-        # construct baseline cache_key from query_context with post processing operation
+        # construct baseline query_cache_key from query_context with post processing operation
         query_context = ChartDataQueryContextSchema().load(payload)
         query_object = query_context.queries[0]
        cache_key_original = query_context.query_cache_key(query_object)
 
-        # ensure added None post_processing operation doesn't change cache_key
+        # ensure added None post_processing operation doesn't change query_cache_key
         payload["queries"][0]["post_processing"].append(None)
         query_context = ChartDataQueryContextSchema().load(payload)
         query_object = query_context.queries[0]
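
For reference, a minimal standalone sketch of the decoupled permission-check pattern introduced above (not part of the patch itself): `check_slice_perms` and `slice_json` below are simplified stand-ins for the Superset versions, and the `@etag_cache()` wrapper is omitted since it needs a Flask application context.

from functools import wraps
from typing import Any, Callable


def check_resource_permissions(check_perms: Callable[..., Any]) -> Callable[..., Any]:
    """Run a permission check against the view arguments before the view body executes."""

    def decorator(f: Callable[..., Any]) -> Callable[..., Any]:
        @wraps(f)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            # check if the user can access the resource; expected to raise on failure
            check_perms(*args, **kwargs)
            return f(*args, **kwargs)

        return wrapper

    return decorator


def check_slice_perms(slice_id: int) -> None:
    # stand-in for superset.views.utils.check_slice_perms
    if slice_id < 0:
        raise PermissionError(f"No access to slice {slice_id}")


@check_resource_permissions(check_slice_perms)
def slice_json(slice_id: int) -> str:
    # stand-in view body; in Superset this is additionally wrapped by @etag_cache()
    return f"payload for slice {slice_id}"


if __name__ == "__main__":
    print(slice_json(1))  # permission check passes, then the view body runs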