From 103cd3d6f35e9288e317629064bedb6debdf7a69 Mon Sep 17 00:00:00 2001 From: nsivarajan <117266407+nsivarajan@users.noreply.github.com> Date: Fri, 30 Aug 2024 23:12:29 +0530 Subject: [PATCH] feat(GAQ): Add Redis Sentinel Support for Global Async Queries (#29912) Co-authored-by: Sivarajan Narayanan --- superset/async_events/async_query_manager.py | 80 +++++-- superset/async_events/cache_backend.py | 209 ++++++++++++++++++ superset/config.py | 22 ++ .../async_events/api_tests.py | 53 +++-- .../tasks/async_queries_tests.py | 94 +++++++- .../async_events/async_query_manager_tests.py | 37 +++- 6 files changed, 450 insertions(+), 45 deletions(-) create mode 100644 superset/async_events/cache_backend.py diff --git a/superset/async_events/async_query_manager.py b/superset/async_events/async_query_manager.py index 84b20d753..b116c3cfc 100644 --- a/superset/async_events/async_query_manager.py +++ b/superset/async_events/async_query_manager.py @@ -14,20 +14,31 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + import logging import uuid -from typing import Any, Literal, Optional +from typing import Any, Literal, Optional, Union import jwt import redis from flask import Flask, Request, request, Response, session +from flask_caching.backends.base import BaseCache +from superset.async_events.cache_backend import ( + RedisCacheBackend, + RedisSentinelCacheBackend, +) from superset.utils import json from superset.utils.core import get_user_id logger = logging.getLogger(__name__) +class CacheBackendNotInitialized(Exception): + pass + + class AsyncQueryTokenException(Exception): pass @@ -55,13 +66,32 @@ def parse_event(event_data: tuple[str, dict[str, Any]]) -> dict[str, Any]: return {"id": event_id, **json.loads(event_payload)} -def increment_id(redis_id: str) -> str: +def increment_id(entry_id: str) -> str: # redis stream IDs are in this format: '1607477697866-0' try: - prefix, last = redis_id[:-1], int(redis_id[-1]) + prefix, last = entry_id[:-1], int(entry_id[-1]) return prefix + str(last + 1) except Exception: # pylint: disable=broad-except - return redis_id + return entry_id + + +def get_cache_backend( + config: dict[str, Any], +) -> Union[RedisCacheBackend, RedisSentinelCacheBackend, redis.Redis]: # type: ignore + cache_config = config.get("GLOBAL_ASYNC_QUERIES_CACHE_BACKEND", {}) + cache_type = cache_config.get("CACHE_TYPE") + + if cache_type == "RedisCache": + return RedisCacheBackend.from_config(cache_config) + + if cache_type == "RedisSentinelCache": + return RedisSentinelCacheBackend.from_config(cache_config) + + # TODO: Deprecate hardcoded plain Redis code and expand cache backend options. + # Maintain backward compatibility with 'GLOBAL_ASYNC_QUERIES_REDIS_CONFIG' until it is deprecated. + return redis.Redis( + **config["GLOBAL_ASYNC_QUERIES_REDIS_CONFIG"], decode_responses=True + ) class AsyncQueryManager: @@ -73,7 +103,7 @@ class AsyncQueryManager: def __init__(self) -> None: super().__init__() - self._redis: redis.Redis # type: ignore + self._cache: Optional[BaseCache] = None self._stream_prefix: str = "" self._stream_limit: Optional[int] self._stream_limit_firehose: Optional[int] @@ -88,10 +118,9 @@ class AsyncQueryManager: def init_app(self, app: Flask) -> None: config = app.config - if ( - config["CACHE_CONFIG"]["CACHE_TYPE"] == "null" - or config["DATA_CACHE_CONFIG"]["CACHE_TYPE"] == "null" - ): + cache_type = config.get("CACHE_CONFIG", {}).get("CACHE_TYPE") + data_cache_type = config.get("DATA_CACHE_CONFIG", {}).get("CACHE_TYPE") + if cache_type in [None, "null"] or data_cache_type in [None, "null"]: raise Exception( # pylint: disable=broad-exception-raised """ Cache backends (CACHE_CONFIG, DATA_CACHE_CONFIG) must be configured @@ -99,14 +128,14 @@ class AsyncQueryManager: """ ) + self._cache = get_cache_backend(config) + logger.debug("Using GAQ Cache backend as %s", type(self._cache).__name__) + if len(config["GLOBAL_ASYNC_QUERIES_JWT_SECRET"]) < 32: raise AsyncQueryTokenException( "Please provide a JWT secret at least 32 bytes long" ) - self._redis = redis.Redis( - **config["GLOBAL_ASYNC_QUERIES_REDIS_CONFIG"], decode_responses=True - ) self._stream_prefix = config["GLOBAL_ASYNC_QUERIES_REDIS_STREAM_PREFIX"] self._stream_limit = config["GLOBAL_ASYNC_QUERIES_REDIS_STREAM_LIMIT"] self._stream_limit_firehose = config[ @@ -230,14 +259,35 @@ class AsyncQueryManager: def read_events( self, channel: str, last_id: Optional[str] ) -> list[Optional[dict[str, Any]]]: + if not self._cache: + raise CacheBackendNotInitialized("Cache backend not initialized") + stream_name = f"{self._stream_prefix}{channel}" start_id = increment_id(last_id) if last_id else "-" - results = self._redis.xrange(stream_name, start_id, "+", self.MAX_EVENT_COUNT) + results = self._cache.xrange(stream_name, start_id, "+", self.MAX_EVENT_COUNT) + # Decode bytes to strings, decode_responses is not supported at RedisCache and RedisSentinelCache + if isinstance(self._cache, (RedisSentinelCacheBackend, RedisCacheBackend)): + decoded_results = [ + ( + event_id.decode("utf-8"), + { + key.decode("utf-8"): value.decode("utf-8") + for key, value in event_data.items() + }, + ) + for event_id, event_data in results + ] + return ( + [] if not decoded_results else list(map(parse_event, decoded_results)) + ) return [] if not results else list(map(parse_event, results)) def update_job( self, job_metadata: dict[str, Any], status: str, **kwargs: Any ) -> None: + if not self._cache: + raise CacheBackendNotInitialized("Cache backend not initialized") + if "channel_id" not in job_metadata: raise AsyncQueryJobException("No channel ID specified") @@ -253,5 +303,5 @@ class AsyncQueryManager: logger.debug("********** logging event data to stream %s", scoped_stream_name) logger.debug(event_data) - self._redis.xadd(scoped_stream_name, event_data, "*", self._stream_limit) - self._redis.xadd(full_stream_name, event_data, "*", self._stream_limit_firehose) + self._cache.xadd(scoped_stream_name, event_data, "*", self._stream_limit) + self._cache.xadd(full_stream_name, event_data, "*", self._stream_limit_firehose) diff --git a/superset/async_events/cache_backend.py b/superset/async_events/cache_backend.py new file mode 100644 index 000000000..15887e47a --- /dev/null +++ b/superset/async_events/cache_backend.py @@ -0,0 +1,209 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Any, Dict, List, Optional, Tuple + +import redis +from flask_caching.backends.rediscache import RedisCache, RedisSentinelCache +from redis.sentinel import Sentinel + + +class RedisCacheBackend(RedisCache): + MAX_EVENT_COUNT = 100 + + def __init__( # pylint: disable=too-many-arguments + self, + host: str, + port: int, + password: Optional[str] = None, + db: int = 0, + default_timeout: int = 300, + key_prefix: Optional[str] = None, + ssl: bool = False, + ssl_certfile: Optional[str] = None, + ssl_keyfile: Optional[str] = None, + ssl_cert_reqs: str = "required", + ssl_ca_certs: Optional[str] = None, + **kwargs: Any, + ) -> None: + super().__init__( + host=host, + port=port, + password=password, + db=db, + default_timeout=default_timeout, + key_prefix=key_prefix, + **kwargs, + ) + self._cache = redis.Redis( + host=host, + port=port, + password=password, + db=db, + ssl=ssl, + ssl_certfile=ssl_certfile, + ssl_keyfile=ssl_keyfile, + ssl_cert_reqs=ssl_cert_reqs, + ssl_ca_certs=ssl_ca_certs, + **kwargs, + ) + + def xadd( + self, + stream_name: str, + event_data: Dict[str, Any], + event_id: str = "*", + maxlen: Optional[int] = None, + ) -> str: + return self._cache.xadd(stream_name, event_data, event_id, maxlen) + + def xrange( + self, + stream_name: str, + start: str = "-", + end: str = "+", + count: Optional[int] = None, + ) -> List[Any]: + count = count or self.MAX_EVENT_COUNT + return self._cache.xrange(stream_name, start, end, count) + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "RedisCacheBackend": + kwargs = { + "host": config.get("CACHE_REDIS_HOST", "localhost"), + "port": config.get("CACHE_REDIS_PORT", 6379), + "db": config.get("CACHE_REDIS_DB", 0), + "password": config.get("CACHE_REDIS_PASSWORD", None), + "key_prefix": config.get("CACHE_KEY_PREFIX", None), + "default_timeout": config.get("CACHE_DEFAULT_TIMEOUT", 300), + "ssl": config.get("CACHE_REDIS_SSL", False), + "ssl_certfile": config.get("CACHE_REDIS_SSL_CERTFILE", None), + "ssl_keyfile": config.get("CACHE_REDIS_SSL_KEYFILE", None), + "ssl_cert_reqs": config.get("CACHE_REDIS_SSL_CERT_REQS", "required"), + "ssl_ca_certs": config.get("CACHE_REDIS_SSL_CA_CERTS", None), + } + return cls(**kwargs) + + +class RedisSentinelCacheBackend(RedisSentinelCache): + MAX_EVENT_COUNT = 100 + + def __init__( # pylint: disable=too-many-arguments + self, + sentinels: List[Tuple[str, int]], + master: str, + password: Optional[str] = None, + sentinel_password: Optional[str] = None, + db: int = 0, + default_timeout: int = 300, + key_prefix: str = "", + ssl: bool = False, + ssl_certfile: Optional[str] = None, + ssl_keyfile: Optional[str] = None, + ssl_cert_reqs: str = "required", + ssl_ca_certs: Optional[str] = None, + **kwargs: Any, + ) -> None: + # Sentinel dont directly support SSL + # Initialize Sentinel without SSL parameters + self._sentinel = Sentinel( + sentinels, + sentinel_kwargs={ + "password": sentinel_password, + }, + **{ + k: v + for k, v in kwargs.items() + if k + not in [ + "ssl", + "ssl_certfile", + "ssl_keyfile", + "ssl_cert_reqs", + "ssl_ca_certs", + ] + }, + ) + + # Prepare SSL-related arguments for master_for method + master_kwargs = { + "password": password, + "ssl": ssl, + "ssl_certfile": ssl_certfile if ssl else None, + "ssl_keyfile": ssl_keyfile if ssl else None, + "ssl_cert_reqs": ssl_cert_reqs if ssl else None, + "ssl_ca_certs": ssl_ca_certs if ssl else None, + } + + # If SSL is False, remove all SSL-related keys + # SSL_* are expected only if SSL is True + if not ssl: + master_kwargs = { + k: v for k, v in master_kwargs.items() if not k.startswith("ssl") + } + + # Filter out None values from master_kwargs + master_kwargs = {k: v for k, v in master_kwargs.items() if v is not None} + + # Initialize Redis master connection + self._cache = self._sentinel.master_for(master, **master_kwargs) + + # Call the parent class constructor + super().__init__( + host=None, + port=None, + password=password, + db=db, + default_timeout=default_timeout, + key_prefix=key_prefix, + **kwargs, + ) + + def xadd( + self, + stream_name: str, + event_data: Dict[str, Any], + event_id: str = "*", + maxlen: Optional[int] = None, + ) -> str: + return self._cache.xadd(stream_name, event_data, event_id, maxlen) + + def xrange( + self, + stream_name: str, + start: str = "-", + end: str = "+", + count: Optional[int] = None, + ) -> List[Any]: + count = count or self.MAX_EVENT_COUNT + return self._cache.xrange(stream_name, start, end, count) + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "RedisSentinelCacheBackend": + kwargs = { + "sentinels": config.get("CACHE_REDIS_SENTINELS", [("127.0.0.1", 26379)]), + "master": config.get("CACHE_REDIS_SENTINEL_MASTER", "mymaster"), + "password": config.get("CACHE_REDIS_PASSWORD", None), + "sentinel_password": config.get("CACHE_REDIS_SENTINEL_PASSWORD", None), + "key_prefix": config.get("CACHE_KEY_PREFIX", ""), + "db": config.get("CACHE_REDIS_DB", 0), + "ssl": config.get("CACHE_REDIS_SSL", False), + "ssl_certfile": config.get("CACHE_REDIS_SSL_CERTFILE", None), + "ssl_keyfile": config.get("CACHE_REDIS_SSL_KEYFILE", None), + "ssl_cert_reqs": config.get("CACHE_REDIS_SSL_CERT_REQS", "required"), + "ssl_ca_certs": config.get("CACHE_REDIS_SSL_CA_CERTS", None), + } + return cls(**kwargs) diff --git a/superset/config.py b/superset/config.py index 5b30397a0..670c51893 100644 --- a/superset/config.py +++ b/superset/config.py @@ -1690,6 +1690,28 @@ GLOBAL_ASYNC_QUERIES_POLLING_DELAY = int( ) GLOBAL_ASYNC_QUERIES_WEBSOCKET_URL = "ws://127.0.0.1:8080/" +# Global async queries cache backend configuration options: +# - Set 'CACHE_TYPE' to 'RedisCache' for RedisCacheBackend. +# - Set 'CACHE_TYPE' to 'RedisSentinelCache' for RedisSentinelCacheBackend. +# - Set 'CACHE_TYPE' to 'None' to fall back on 'GLOBAL_ASYNC_QUERIES_REDIS_CONFIG'. +GLOBAL_ASYNC_QUERIES_CACHE_BACKEND = { + "CACHE_TYPE": "RedisCache", + "CACHE_REDIS_HOST": "localhost", + "CACHE_REDIS_PORT": 6379, + "CACHE_REDIS_USER": "", + "CACHE_REDIS_PASSWORD": "", + "CACHE_REDIS_DB": 0, + "CACHE_DEFAULT_TIMEOUT": 300, + "CACHE_REDIS_SENTINELS": [("localhost", 26379)], + "CACHE_REDIS_SENTINEL_MASTER": "mymaster", + "CACHE_REDIS_SENTINEL_PASSWORD": None, + "CACHE_REDIS_SSL": False, # True or False + "CACHE_REDIS_SSL_CERTFILE": None, + "CACHE_REDIS_SSL_KEYFILE": None, + "CACHE_REDIS_SSL_CERT_REQS": "required", + "CACHE_REDIS_SSL_CA_CERTS": None, +} + # Embedded config options GUEST_ROLE_NAME = "Public" GUEST_TOKEN_JWT_SECRET = "test-guest-secret-change-me" diff --git a/tests/integration_tests/async_events/api_tests.py b/tests/integration_tests/async_events/api_tests.py index 66aef25c2..5a8189f9a 100644 --- a/tests/integration_tests/async_events/api_tests.py +++ b/tests/integration_tests/async_events/api_tests.py @@ -14,9 +14,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Optional +from typing import Any, Optional, Type from unittest import mock +import redis + +from superset.async_events.cache_backend import ( + RedisCacheBackend, + RedisSentinelCacheBackend, +) from superset.extensions import async_query_manager from superset.utils import json from tests.integration_tests.base_tests import SupersetTestCase @@ -32,12 +38,21 @@ class TestAsyncEventApi(SupersetTestCase): uri = f"{base_uri}?last_id={last_id}" if last_id else base_uri return self.client.get(uri) - @mock.patch("uuid.uuid4", return_value=UUID) - def test_events(self, mock_uuid4): + def run_test_with_cache_backend(self, cache_backend_cls: Type[Any], test_func): app._got_first_request = False async_query_manager.init_app(app) + + # Create a mock cache backend instance + mock_cache = mock.Mock(spec=cache_backend_cls) + + # Set the mock cache instance + async_query_manager._cache = mock_cache + self.login(ADMIN_USERNAME) - with mock.patch.object(async_query_manager._redis, "xrange") as mock_xrange: + test_func(mock_cache) + + def _test_events_logic(self, mock_cache): + with mock.patch.object(mock_cache, "xrange") as mock_xrange: rv = self.fetch_events() response = json.loads(rv.data.decode("utf-8")) @@ -46,12 +61,8 @@ class TestAsyncEventApi(SupersetTestCase): mock_xrange.assert_called_with(channel_id, "-", "+", 100) self.assertEqual(response, {"result": []}) - @mock.patch("uuid.uuid4", return_value=UUID) - def test_events_last_id(self, mock_uuid4): - app._got_first_request = False - async_query_manager.init_app(app) - self.login(ADMIN_USERNAME) - with mock.patch.object(async_query_manager._redis, "xrange") as mock_xrange: + def _test_events_last_id_logic(self, mock_cache): + with mock.patch.object(mock_cache, "xrange") as mock_xrange: rv = self.fetch_events("1607471525180-0") response = json.loads(rv.data.decode("utf-8")) @@ -60,12 +71,8 @@ class TestAsyncEventApi(SupersetTestCase): mock_xrange.assert_called_with(channel_id, "1607471525180-1", "+", 100) self.assertEqual(response, {"result": []}) - @mock.patch("uuid.uuid4", return_value=UUID) - def test_events_results(self, mock_uuid4): - app._got_first_request = False - async_query_manager.init_app(app) - self.login(ADMIN_USERNAME) - with mock.patch.object(async_query_manager._redis, "xrange") as mock_xrange: + def _test_events_results_logic(self, mock_cache): + with mock.patch.object(mock_cache, "xrange") as mock_xrange: mock_xrange.return_value = [ ( "1607477697866-0", @@ -110,6 +117,20 @@ class TestAsyncEventApi(SupersetTestCase): } self.assertEqual(response, expected) + @mock.patch("uuid.uuid4", return_value=UUID) + def test_events_redis_cache_backend(self, mock_uuid4): + self.run_test_with_cache_backend(RedisCacheBackend, self._test_events_logic) + + @mock.patch("uuid.uuid4", return_value=UUID) + def test_events_redis_sentinel_cache_backend(self, mock_uuid4): + self.run_test_with_cache_backend( + RedisSentinelCacheBackend, self._test_events_logic + ) + + @mock.patch("uuid.uuid4", return_value=UUID) + def test_events_redis(self, mock_uuid4): + self.run_test_with_cache_backend(redis.Redis, self._test_events_logic) + def test_events_no_login(self): app._got_first_request = False async_query_manager.init_app(app) diff --git a/tests/integration_tests/tasks/async_queries_tests.py b/tests/integration_tests/tasks/async_queries_tests.py index 8abfe691d..01b759c35 100644 --- a/tests/integration_tests/tasks/async_queries_tests.py +++ b/tests/integration_tests/tasks/async_queries_tests.py @@ -20,8 +20,14 @@ from unittest import mock from uuid import uuid4 import pytest +import redis from celery.exceptions import SoftTimeLimitExceeded +from parameterized import parameterized +from superset.async_events.cache_backend import ( + RedisCacheBackend, + RedisSentinelCacheBackend, +) from superset.commands.chart.data.get_data_command import ChartDataCommand from superset.commands.chart.exceptions import ChartDataQueryFailedError from superset.exceptions import SupersetException @@ -38,17 +44,29 @@ from tests.integration_tests.fixtures.tags import ( from tests.integration_tests.test_app import app +@pytest.mark.usefixtures( + "load_birth_names_data", "load_birth_names_dashboard_with_slices" +) class TestAsyncQueries(SupersetTestCase): - @pytest.mark.usefixtures( - "load_birth_names_data", "load_birth_names_dashboard_with_slices" + @parameterized.expand( + [ + ("RedisCacheBackend", mock.Mock(spec=RedisCacheBackend)), + ("RedisSentinelCacheBackend", mock.Mock(spec=RedisSentinelCacheBackend)), + ("redis.Redis", mock.Mock(spec=redis.Redis)), + ] ) - @mock.patch.object(async_query_manager, "update_job") @mock.patch("superset.tasks.async_queries.set_form_data") - def test_load_chart_data_into_cache(self, mock_set_form_data, mock_update_job): + @mock.patch.object(async_query_manager, "update_job") + def test_load_chart_data_into_cache( + self, cache_type, cache_backend, mock_update_job, mock_set_form_data + ): from superset.tasks.async_queries import load_chart_data_into_cache app._got_first_request = False + + async_query_manager.get_cache_backend = mock.Mock(return_value=cache_backend) async_query_manager.init_app(app) + query_context = get_query_context("birth_names") user = security_manager.find_user("gamma") job_metadata = { @@ -60,20 +78,33 @@ class TestAsyncQueries(SupersetTestCase): } load_chart_data_into_cache(job_metadata, query_context) + mock_set_form_data.assert_called_once_with(query_context) mock_update_job.assert_called_once_with( job_metadata, "done", result_url=mock.ANY ) + @parameterized.expand( + [ + ("RedisCacheBackend", mock.Mock(spec=RedisCacheBackend)), + ("RedisSentinelCacheBackend", mock.Mock(spec=RedisSentinelCacheBackend)), + ("redis.Redis", mock.Mock(spec=redis.Redis)), + ] + ) @mock.patch.object( ChartDataCommand, "run", side_effect=ChartDataQueryFailedError("Error: foo") ) @mock.patch.object(async_query_manager, "update_job") - def test_load_chart_data_into_cache_error(self, mock_update_job, mock_run_command): + def test_load_chart_data_into_cache_error( + self, cache_type, cache_backend, mock_update_job, mock_run_command + ): from superset.tasks.async_queries import load_chart_data_into_cache app._got_first_request = False + + async_query_manager.get_cache_backend = mock.Mock(return_value=cache_backend) async_query_manager.init_app(app) + query_context = get_query_context("birth_names") user = security_manager.find_user("gamma") job_metadata = { @@ -90,15 +121,25 @@ class TestAsyncQueries(SupersetTestCase): errors = [{"message": "Error: foo"}] mock_update_job.assert_called_once_with(job_metadata, "error", errors=errors) + @parameterized.expand( + [ + ("RedisCacheBackend", mock.Mock(spec=RedisCacheBackend)), + ("RedisSentinelCacheBackend", mock.Mock(spec=RedisSentinelCacheBackend)), + ("redis.Redis", mock.Mock(spec=redis.Redis)), + ] + ) @mock.patch.object(ChartDataCommand, "run") @mock.patch.object(async_query_manager, "update_job") def test_soft_timeout_load_chart_data_into_cache( - self, mock_update_job, mock_run_command + self, cache_type, cache_backend, mock_update_job, mock_run_command ): from superset.tasks.async_queries import load_chart_data_into_cache app._got_first_request = False + + async_query_manager.get_cache_backend = mock.Mock(return_value=cache_backend) async_query_manager.init_app(app) + user = security_manager.find_user("gamma") form_data = {} job_metadata = { @@ -118,13 +159,25 @@ class TestAsyncQueries(SupersetTestCase): load_chart_data_into_cache(job_metadata, form_data) set_form_data.assert_called_once_with(form_data, "error", errors=errors) + @parameterized.expand( + [ + ("RedisCacheBackend", mock.Mock(spec=RedisCacheBackend)), + ("RedisSentinelCacheBackend", mock.Mock(spec=RedisSentinelCacheBackend)), + ("redis.Redis", mock.Mock(spec=redis.Redis)), + ] + ) @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @mock.patch.object(async_query_manager, "update_job") - def test_load_explore_json_into_cache(self, mock_update_job): + def test_load_explore_json_into_cache( + self, cache_type, cache_backend, mock_update_job + ): from superset.tasks.async_queries import load_explore_json_into_cache app._got_first_request = False + + async_query_manager.get_cache_backend = mock.Mock(return_value=cache_backend) async_query_manager.init_app(app) + table = self.get_table(name="birth_names") user = security_manager.find_user("gamma") form_data = { @@ -146,19 +199,30 @@ class TestAsyncQueries(SupersetTestCase): } load_explore_json_into_cache(job_metadata, form_data) + mock_update_job.assert_called_once_with( job_metadata, "done", result_url=mock.ANY ) + @parameterized.expand( + [ + ("RedisCacheBackend", mock.Mock(spec=RedisCacheBackend)), + ("RedisSentinelCacheBackend", mock.Mock(spec=RedisSentinelCacheBackend)), + ("redis.Redis", mock.Mock(spec=redis.Redis)), + ] + ) @mock.patch.object(async_query_manager, "update_job") @mock.patch("superset.tasks.async_queries.set_form_data") def test_load_explore_json_into_cache_error( - self, mock_set_form_data, mock_update_job + self, cache_type, cache_backend, mock_set_form_data, mock_update_job ): from superset.tasks.async_queries import load_explore_json_into_cache app._got_first_request = False + + async_query_manager.get_cache_backend = mock.Mock(return_value=cache_backend) async_query_manager.init_app(app) + user = security_manager.find_user("gamma") form_data = {} job_metadata = { @@ -176,15 +240,25 @@ class TestAsyncQueries(SupersetTestCase): errors = ["The dataset associated with this chart no longer exists"] mock_update_job.assert_called_once_with(job_metadata, "error", errors=errors) + @parameterized.expand( + [ + ("RedisCacheBackend", mock.Mock(spec=RedisCacheBackend)), + ("RedisSentinelCacheBackend", mock.Mock(spec=RedisSentinelCacheBackend)), + ("redis.Redis", mock.Mock(spec=redis.Redis)), + ] + ) @mock.patch.object(ChartDataCommand, "run") @mock.patch.object(async_query_manager, "update_job") def test_soft_timeout_load_explore_json_into_cache( - self, mock_update_job, mock_run_command + self, cache_type, cache_backend, mock_update_job, mock_run_command ): from superset.tasks.async_queries import load_explore_json_into_cache app._got_first_request = False + + async_query_manager.get_cache_backend = mock.Mock(return_value=cache_backend) async_query_manager.init_app(app) + user = security_manager.find_user("gamma") form_data = {} job_metadata = { @@ -194,7 +268,7 @@ class TestAsyncQueries(SupersetTestCase): "status": "pending", "errors": [], } - errors = ["A timeout occurred while loading explore json, error"] + errors = ["A timeout occurred while loading explore JSON data"] with pytest.raises(SoftTimeLimitExceeded): with mock.patch( diff --git a/tests/unit_tests/async_events/async_query_manager_tests.py b/tests/unit_tests/async_events/async_query_manager_tests.py index 85ea11420..2ccae644a 100644 --- a/tests/unit_tests/async_events/async_query_manager_tests.py +++ b/tests/unit_tests/async_events/async_query_manager_tests.py @@ -17,15 +17,20 @@ from unittest import mock from unittest.mock import ANY, Mock +import redis from flask import g from jwt import encode -from pytest import fixture, raises +from pytest import fixture, mark, raises from superset import security_manager from superset.async_events.async_query_manager import ( AsyncQueryManager, AsyncQueryTokenException, ) +from superset.async_events.cache_backend import ( + RedisCacheBackend, + RedisSentinelCacheBackend, +) JWT_TOKEN_SECRET = "some_secret" JWT_TOKEN_COOKIE_NAME = "superset_async_jwt" @@ -36,7 +41,6 @@ def async_query_manager(): query_manager = AsyncQueryManager() query_manager._jwt_secret = JWT_TOKEN_SECRET query_manager._jwt_cookie_name = JWT_TOKEN_COOKIE_NAME - return query_manager @@ -75,12 +79,24 @@ def test_parse_channel_id_from_request_bad_jwt(async_query_manager): async_query_manager.parse_channel_id_from_request(request) +@mark.parametrize( + "cache_type, cache_backend", + [ + ("RedisCacheBackend", mock.Mock(spec=RedisCacheBackend)), + ("RedisSentinelCacheBackend", mock.Mock(spec=RedisSentinelCacheBackend)), + ("redis.Redis", mock.Mock(spec=redis.Redis)), + ], +) @mock.patch("superset.is_feature_enabled") def test_submit_chart_data_job_as_guest_user( - is_feature_enabled_mock, async_query_manager + is_feature_enabled_mock, async_query_manager, cache_type, cache_backend ): is_feature_enabled_mock.return_value = True set_current_as_guest_user() + + # Mock the get_cache_backend method to return the current cache backend + async_query_manager.get_cache_backend = mock.Mock(return_value=cache_backend) + job_mock = Mock() async_query_manager._load_chart_data_into_cache_job = job_mock job_meta = async_query_manager.submit_chart_data_job( @@ -105,14 +121,27 @@ def test_submit_chart_data_job_as_guest_user( ) assert "guest_token" not in job_meta + job_mock.reset_mock() # Reset the mock for the next iteration +@mark.parametrize( + "cache_type, cache_backend", + [ + ("RedisCacheBackend", mock.Mock(spec=RedisCacheBackend)), + ("RedisSentinelCacheBackend", mock.Mock(spec=RedisSentinelCacheBackend)), + ("redis.Redis", mock.Mock(spec=redis.Redis)), + ], +) @mock.patch("superset.is_feature_enabled") def test_submit_explore_json_job_as_guest_user( - is_feature_enabled_mock, async_query_manager + is_feature_enabled_mock, async_query_manager, cache_type, cache_backend ): is_feature_enabled_mock.return_value = True set_current_as_guest_user() + + # Mock the get_cache_backend method to return the current cache backend + async_query_manager.get_cache_backend = mock.Mock(return_value=cache_backend) + job_mock = Mock() async_query_manager._load_explore_json_into_cache_job = job_mock job_meta = async_query_manager.submit_explore_json_job(