feat(GAQ): Add Redis Sentinel Support for Global Async Queries (#29912)

Co-authored-by: Sivarajan Narayanan <narayanan_sivarajan@apple.com>
This commit is contained in:
nsivarajan 2024-08-30 23:12:29 +05:30 committed by GitHub
parent cd6b8b2f6d
commit 103cd3d6f3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 450 additions and 45 deletions

View File

@ -14,20 +14,31 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import logging
import uuid
from typing import Any, Literal, Optional
from typing import Any, Literal, Optional, Union
import jwt
import redis
from flask import Flask, Request, request, Response, session
from flask_caching.backends.base import BaseCache
from superset.async_events.cache_backend import (
RedisCacheBackend,
RedisSentinelCacheBackend,
)
from superset.utils import json
from superset.utils.core import get_user_id
logger = logging.getLogger(__name__)
class CacheBackendNotInitialized(Exception):
pass
class AsyncQueryTokenException(Exception):
pass
@ -55,13 +66,32 @@ def parse_event(event_data: tuple[str, dict[str, Any]]) -> dict[str, Any]:
return {"id": event_id, **json.loads(event_payload)}
def increment_id(redis_id: str) -> str:
def increment_id(entry_id: str) -> str:
# redis stream IDs are in this format: '1607477697866-0'
try:
prefix, last = redis_id[:-1], int(redis_id[-1])
prefix, last = entry_id[:-1], int(entry_id[-1])
return prefix + str(last + 1)
except Exception: # pylint: disable=broad-except
return redis_id
return entry_id
def get_cache_backend(
config: dict[str, Any],
) -> Union[RedisCacheBackend, RedisSentinelCacheBackend, redis.Redis]: # type: ignore
cache_config = config.get("GLOBAL_ASYNC_QUERIES_CACHE_BACKEND", {})
cache_type = cache_config.get("CACHE_TYPE")
if cache_type == "RedisCache":
return RedisCacheBackend.from_config(cache_config)
if cache_type == "RedisSentinelCache":
return RedisSentinelCacheBackend.from_config(cache_config)
# TODO: Deprecate hardcoded plain Redis code and expand cache backend options.
# Maintain backward compatibility with 'GLOBAL_ASYNC_QUERIES_REDIS_CONFIG' until it is deprecated.
return redis.Redis(
**config["GLOBAL_ASYNC_QUERIES_REDIS_CONFIG"], decode_responses=True
)
class AsyncQueryManager:
@ -73,7 +103,7 @@ class AsyncQueryManager:
def __init__(self) -> None:
super().__init__()
self._redis: redis.Redis # type: ignore
self._cache: Optional[BaseCache] = None
self._stream_prefix: str = ""
self._stream_limit: Optional[int]
self._stream_limit_firehose: Optional[int]
@ -88,10 +118,9 @@ class AsyncQueryManager:
def init_app(self, app: Flask) -> None:
config = app.config
if (
config["CACHE_CONFIG"]["CACHE_TYPE"] == "null"
or config["DATA_CACHE_CONFIG"]["CACHE_TYPE"] == "null"
):
cache_type = config.get("CACHE_CONFIG", {}).get("CACHE_TYPE")
data_cache_type = config.get("DATA_CACHE_CONFIG", {}).get("CACHE_TYPE")
if cache_type in [None, "null"] or data_cache_type in [None, "null"]:
raise Exception( # pylint: disable=broad-exception-raised
"""
Cache backends (CACHE_CONFIG, DATA_CACHE_CONFIG) must be configured
@ -99,14 +128,14 @@ class AsyncQueryManager:
"""
)
self._cache = get_cache_backend(config)
logger.debug("Using GAQ Cache backend as %s", type(self._cache).__name__)
if len(config["GLOBAL_ASYNC_QUERIES_JWT_SECRET"]) < 32:
raise AsyncQueryTokenException(
"Please provide a JWT secret at least 32 bytes long"
)
self._redis = redis.Redis(
**config["GLOBAL_ASYNC_QUERIES_REDIS_CONFIG"], decode_responses=True
)
self._stream_prefix = config["GLOBAL_ASYNC_QUERIES_REDIS_STREAM_PREFIX"]
self._stream_limit = config["GLOBAL_ASYNC_QUERIES_REDIS_STREAM_LIMIT"]
self._stream_limit_firehose = config[
@ -230,14 +259,35 @@ class AsyncQueryManager:
def read_events(
self, channel: str, last_id: Optional[str]
) -> list[Optional[dict[str, Any]]]:
if not self._cache:
raise CacheBackendNotInitialized("Cache backend not initialized")
stream_name = f"{self._stream_prefix}{channel}"
start_id = increment_id(last_id) if last_id else "-"
results = self._redis.xrange(stream_name, start_id, "+", self.MAX_EVENT_COUNT)
results = self._cache.xrange(stream_name, start_id, "+", self.MAX_EVENT_COUNT)
# Decode bytes to strings, decode_responses is not supported at RedisCache and RedisSentinelCache
if isinstance(self._cache, (RedisSentinelCacheBackend, RedisCacheBackend)):
decoded_results = [
(
event_id.decode("utf-8"),
{
key.decode("utf-8"): value.decode("utf-8")
for key, value in event_data.items()
},
)
for event_id, event_data in results
]
return (
[] if not decoded_results else list(map(parse_event, decoded_results))
)
return [] if not results else list(map(parse_event, results))
def update_job(
self, job_metadata: dict[str, Any], status: str, **kwargs: Any
) -> None:
if not self._cache:
raise CacheBackendNotInitialized("Cache backend not initialized")
if "channel_id" not in job_metadata:
raise AsyncQueryJobException("No channel ID specified")
@ -253,5 +303,5 @@ class AsyncQueryManager:
logger.debug("********** logging event data to stream %s", scoped_stream_name)
logger.debug(event_data)
self._redis.xadd(scoped_stream_name, event_data, "*", self._stream_limit)
self._redis.xadd(full_stream_name, event_data, "*", self._stream_limit_firehose)
self._cache.xadd(scoped_stream_name, event_data, "*", self._stream_limit)
self._cache.xadd(full_stream_name, event_data, "*", self._stream_limit_firehose)

View File

@ -0,0 +1,209 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, Dict, List, Optional, Tuple
import redis
from flask_caching.backends.rediscache import RedisCache, RedisSentinelCache
from redis.sentinel import Sentinel
class RedisCacheBackend(RedisCache):
MAX_EVENT_COUNT = 100
def __init__( # pylint: disable=too-many-arguments
self,
host: str,
port: int,
password: Optional[str] = None,
db: int = 0,
default_timeout: int = 300,
key_prefix: Optional[str] = None,
ssl: bool = False,
ssl_certfile: Optional[str] = None,
ssl_keyfile: Optional[str] = None,
ssl_cert_reqs: str = "required",
ssl_ca_certs: Optional[str] = None,
**kwargs: Any,
) -> None:
super().__init__(
host=host,
port=port,
password=password,
db=db,
default_timeout=default_timeout,
key_prefix=key_prefix,
**kwargs,
)
self._cache = redis.Redis(
host=host,
port=port,
password=password,
db=db,
ssl=ssl,
ssl_certfile=ssl_certfile,
ssl_keyfile=ssl_keyfile,
ssl_cert_reqs=ssl_cert_reqs,
ssl_ca_certs=ssl_ca_certs,
**kwargs,
)
def xadd(
self,
stream_name: str,
event_data: Dict[str, Any],
event_id: str = "*",
maxlen: Optional[int] = None,
) -> str:
return self._cache.xadd(stream_name, event_data, event_id, maxlen)
def xrange(
self,
stream_name: str,
start: str = "-",
end: str = "+",
count: Optional[int] = None,
) -> List[Any]:
count = count or self.MAX_EVENT_COUNT
return self._cache.xrange(stream_name, start, end, count)
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "RedisCacheBackend":
kwargs = {
"host": config.get("CACHE_REDIS_HOST", "localhost"),
"port": config.get("CACHE_REDIS_PORT", 6379),
"db": config.get("CACHE_REDIS_DB", 0),
"password": config.get("CACHE_REDIS_PASSWORD", None),
"key_prefix": config.get("CACHE_KEY_PREFIX", None),
"default_timeout": config.get("CACHE_DEFAULT_TIMEOUT", 300),
"ssl": config.get("CACHE_REDIS_SSL", False),
"ssl_certfile": config.get("CACHE_REDIS_SSL_CERTFILE", None),
"ssl_keyfile": config.get("CACHE_REDIS_SSL_KEYFILE", None),
"ssl_cert_reqs": config.get("CACHE_REDIS_SSL_CERT_REQS", "required"),
"ssl_ca_certs": config.get("CACHE_REDIS_SSL_CA_CERTS", None),
}
return cls(**kwargs)
class RedisSentinelCacheBackend(RedisSentinelCache):
MAX_EVENT_COUNT = 100
def __init__( # pylint: disable=too-many-arguments
self,
sentinels: List[Tuple[str, int]],
master: str,
password: Optional[str] = None,
sentinel_password: Optional[str] = None,
db: int = 0,
default_timeout: int = 300,
key_prefix: str = "",
ssl: bool = False,
ssl_certfile: Optional[str] = None,
ssl_keyfile: Optional[str] = None,
ssl_cert_reqs: str = "required",
ssl_ca_certs: Optional[str] = None,
**kwargs: Any,
) -> None:
# Sentinel dont directly support SSL
# Initialize Sentinel without SSL parameters
self._sentinel = Sentinel(
sentinels,
sentinel_kwargs={
"password": sentinel_password,
},
**{
k: v
for k, v in kwargs.items()
if k
not in [
"ssl",
"ssl_certfile",
"ssl_keyfile",
"ssl_cert_reqs",
"ssl_ca_certs",
]
},
)
# Prepare SSL-related arguments for master_for method
master_kwargs = {
"password": password,
"ssl": ssl,
"ssl_certfile": ssl_certfile if ssl else None,
"ssl_keyfile": ssl_keyfile if ssl else None,
"ssl_cert_reqs": ssl_cert_reqs if ssl else None,
"ssl_ca_certs": ssl_ca_certs if ssl else None,
}
# If SSL is False, remove all SSL-related keys
# SSL_* are expected only if SSL is True
if not ssl:
master_kwargs = {
k: v for k, v in master_kwargs.items() if not k.startswith("ssl")
}
# Filter out None values from master_kwargs
master_kwargs = {k: v for k, v in master_kwargs.items() if v is not None}
# Initialize Redis master connection
self._cache = self._sentinel.master_for(master, **master_kwargs)
# Call the parent class constructor
super().__init__(
host=None,
port=None,
password=password,
db=db,
default_timeout=default_timeout,
key_prefix=key_prefix,
**kwargs,
)
def xadd(
self,
stream_name: str,
event_data: Dict[str, Any],
event_id: str = "*",
maxlen: Optional[int] = None,
) -> str:
return self._cache.xadd(stream_name, event_data, event_id, maxlen)
def xrange(
self,
stream_name: str,
start: str = "-",
end: str = "+",
count: Optional[int] = None,
) -> List[Any]:
count = count or self.MAX_EVENT_COUNT
return self._cache.xrange(stream_name, start, end, count)
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "RedisSentinelCacheBackend":
kwargs = {
"sentinels": config.get("CACHE_REDIS_SENTINELS", [("127.0.0.1", 26379)]),
"master": config.get("CACHE_REDIS_SENTINEL_MASTER", "mymaster"),
"password": config.get("CACHE_REDIS_PASSWORD", None),
"sentinel_password": config.get("CACHE_REDIS_SENTINEL_PASSWORD", None),
"key_prefix": config.get("CACHE_KEY_PREFIX", ""),
"db": config.get("CACHE_REDIS_DB", 0),
"ssl": config.get("CACHE_REDIS_SSL", False),
"ssl_certfile": config.get("CACHE_REDIS_SSL_CERTFILE", None),
"ssl_keyfile": config.get("CACHE_REDIS_SSL_KEYFILE", None),
"ssl_cert_reqs": config.get("CACHE_REDIS_SSL_CERT_REQS", "required"),
"ssl_ca_certs": config.get("CACHE_REDIS_SSL_CA_CERTS", None),
}
return cls(**kwargs)

View File

@ -1690,6 +1690,28 @@ GLOBAL_ASYNC_QUERIES_POLLING_DELAY = int(
)
GLOBAL_ASYNC_QUERIES_WEBSOCKET_URL = "ws://127.0.0.1:8080/"
# Global async queries cache backend configuration options:
# - Set 'CACHE_TYPE' to 'RedisCache' for RedisCacheBackend.
# - Set 'CACHE_TYPE' to 'RedisSentinelCache' for RedisSentinelCacheBackend.
# - Set 'CACHE_TYPE' to 'None' to fall back on 'GLOBAL_ASYNC_QUERIES_REDIS_CONFIG'.
GLOBAL_ASYNC_QUERIES_CACHE_BACKEND = {
"CACHE_TYPE": "RedisCache",
"CACHE_REDIS_HOST": "localhost",
"CACHE_REDIS_PORT": 6379,
"CACHE_REDIS_USER": "",
"CACHE_REDIS_PASSWORD": "",
"CACHE_REDIS_DB": 0,
"CACHE_DEFAULT_TIMEOUT": 300,
"CACHE_REDIS_SENTINELS": [("localhost", 26379)],
"CACHE_REDIS_SENTINEL_MASTER": "mymaster",
"CACHE_REDIS_SENTINEL_PASSWORD": None,
"CACHE_REDIS_SSL": False, # True or False
"CACHE_REDIS_SSL_CERTFILE": None,
"CACHE_REDIS_SSL_KEYFILE": None,
"CACHE_REDIS_SSL_CERT_REQS": "required",
"CACHE_REDIS_SSL_CA_CERTS": None,
}
# Embedded config options
GUEST_ROLE_NAME = "Public"
GUEST_TOKEN_JWT_SECRET = "test-guest-secret-change-me"

View File

@ -14,9 +14,15 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Optional
from typing import Any, Optional, Type
from unittest import mock
import redis
from superset.async_events.cache_backend import (
RedisCacheBackend,
RedisSentinelCacheBackend,
)
from superset.extensions import async_query_manager
from superset.utils import json
from tests.integration_tests.base_tests import SupersetTestCase
@ -32,12 +38,21 @@ class TestAsyncEventApi(SupersetTestCase):
uri = f"{base_uri}?last_id={last_id}" if last_id else base_uri
return self.client.get(uri)
@mock.patch("uuid.uuid4", return_value=UUID)
def test_events(self, mock_uuid4):
def run_test_with_cache_backend(self, cache_backend_cls: Type[Any], test_func):
app._got_first_request = False
async_query_manager.init_app(app)
# Create a mock cache backend instance
mock_cache = mock.Mock(spec=cache_backend_cls)
# Set the mock cache instance
async_query_manager._cache = mock_cache
self.login(ADMIN_USERNAME)
with mock.patch.object(async_query_manager._redis, "xrange") as mock_xrange:
test_func(mock_cache)
def _test_events_logic(self, mock_cache):
with mock.patch.object(mock_cache, "xrange") as mock_xrange:
rv = self.fetch_events()
response = json.loads(rv.data.decode("utf-8"))
@ -46,12 +61,8 @@ class TestAsyncEventApi(SupersetTestCase):
mock_xrange.assert_called_with(channel_id, "-", "+", 100)
self.assertEqual(response, {"result": []})
@mock.patch("uuid.uuid4", return_value=UUID)
def test_events_last_id(self, mock_uuid4):
app._got_first_request = False
async_query_manager.init_app(app)
self.login(ADMIN_USERNAME)
with mock.patch.object(async_query_manager._redis, "xrange") as mock_xrange:
def _test_events_last_id_logic(self, mock_cache):
with mock.patch.object(mock_cache, "xrange") as mock_xrange:
rv = self.fetch_events("1607471525180-0")
response = json.loads(rv.data.decode("utf-8"))
@ -60,12 +71,8 @@ class TestAsyncEventApi(SupersetTestCase):
mock_xrange.assert_called_with(channel_id, "1607471525180-1", "+", 100)
self.assertEqual(response, {"result": []})
@mock.patch("uuid.uuid4", return_value=UUID)
def test_events_results(self, mock_uuid4):
app._got_first_request = False
async_query_manager.init_app(app)
self.login(ADMIN_USERNAME)
with mock.patch.object(async_query_manager._redis, "xrange") as mock_xrange:
def _test_events_results_logic(self, mock_cache):
with mock.patch.object(mock_cache, "xrange") as mock_xrange:
mock_xrange.return_value = [
(
"1607477697866-0",
@ -110,6 +117,20 @@ class TestAsyncEventApi(SupersetTestCase):
}
self.assertEqual(response, expected)
@mock.patch("uuid.uuid4", return_value=UUID)
def test_events_redis_cache_backend(self, mock_uuid4):
self.run_test_with_cache_backend(RedisCacheBackend, self._test_events_logic)
@mock.patch("uuid.uuid4", return_value=UUID)
def test_events_redis_sentinel_cache_backend(self, mock_uuid4):
self.run_test_with_cache_backend(
RedisSentinelCacheBackend, self._test_events_logic
)
@mock.patch("uuid.uuid4", return_value=UUID)
def test_events_redis(self, mock_uuid4):
self.run_test_with_cache_backend(redis.Redis, self._test_events_logic)
def test_events_no_login(self):
app._got_first_request = False
async_query_manager.init_app(app)

View File

@ -20,8 +20,14 @@ from unittest import mock
from uuid import uuid4
import pytest
import redis
from celery.exceptions import SoftTimeLimitExceeded
from parameterized import parameterized
from superset.async_events.cache_backend import (
RedisCacheBackend,
RedisSentinelCacheBackend,
)
from superset.commands.chart.data.get_data_command import ChartDataCommand
from superset.commands.chart.exceptions import ChartDataQueryFailedError
from superset.exceptions import SupersetException
@ -38,17 +44,29 @@ from tests.integration_tests.fixtures.tags import (
from tests.integration_tests.test_app import app
@pytest.mark.usefixtures(
"load_birth_names_data", "load_birth_names_dashboard_with_slices"
)
class TestAsyncQueries(SupersetTestCase):
@pytest.mark.usefixtures(
"load_birth_names_data", "load_birth_names_dashboard_with_slices"
@parameterized.expand(
[
("RedisCacheBackend", mock.Mock(spec=RedisCacheBackend)),
("RedisSentinelCacheBackend", mock.Mock(spec=RedisSentinelCacheBackend)),
("redis.Redis", mock.Mock(spec=redis.Redis)),
]
)
@mock.patch.object(async_query_manager, "update_job")
@mock.patch("superset.tasks.async_queries.set_form_data")
def test_load_chart_data_into_cache(self, mock_set_form_data, mock_update_job):
@mock.patch.object(async_query_manager, "update_job")
def test_load_chart_data_into_cache(
self, cache_type, cache_backend, mock_update_job, mock_set_form_data
):
from superset.tasks.async_queries import load_chart_data_into_cache
app._got_first_request = False
async_query_manager.get_cache_backend = mock.Mock(return_value=cache_backend)
async_query_manager.init_app(app)
query_context = get_query_context("birth_names")
user = security_manager.find_user("gamma")
job_metadata = {
@ -60,20 +78,33 @@ class TestAsyncQueries(SupersetTestCase):
}
load_chart_data_into_cache(job_metadata, query_context)
mock_set_form_data.assert_called_once_with(query_context)
mock_update_job.assert_called_once_with(
job_metadata, "done", result_url=mock.ANY
)
@parameterized.expand(
[
("RedisCacheBackend", mock.Mock(spec=RedisCacheBackend)),
("RedisSentinelCacheBackend", mock.Mock(spec=RedisSentinelCacheBackend)),
("redis.Redis", mock.Mock(spec=redis.Redis)),
]
)
@mock.patch.object(
ChartDataCommand, "run", side_effect=ChartDataQueryFailedError("Error: foo")
)
@mock.patch.object(async_query_manager, "update_job")
def test_load_chart_data_into_cache_error(self, mock_update_job, mock_run_command):
def test_load_chart_data_into_cache_error(
self, cache_type, cache_backend, mock_update_job, mock_run_command
):
from superset.tasks.async_queries import load_chart_data_into_cache
app._got_first_request = False
async_query_manager.get_cache_backend = mock.Mock(return_value=cache_backend)
async_query_manager.init_app(app)
query_context = get_query_context("birth_names")
user = security_manager.find_user("gamma")
job_metadata = {
@ -90,15 +121,25 @@ class TestAsyncQueries(SupersetTestCase):
errors = [{"message": "Error: foo"}]
mock_update_job.assert_called_once_with(job_metadata, "error", errors=errors)
@parameterized.expand(
[
("RedisCacheBackend", mock.Mock(spec=RedisCacheBackend)),
("RedisSentinelCacheBackend", mock.Mock(spec=RedisSentinelCacheBackend)),
("redis.Redis", mock.Mock(spec=redis.Redis)),
]
)
@mock.patch.object(ChartDataCommand, "run")
@mock.patch.object(async_query_manager, "update_job")
def test_soft_timeout_load_chart_data_into_cache(
self, mock_update_job, mock_run_command
self, cache_type, cache_backend, mock_update_job, mock_run_command
):
from superset.tasks.async_queries import load_chart_data_into_cache
app._got_first_request = False
async_query_manager.get_cache_backend = mock.Mock(return_value=cache_backend)
async_query_manager.init_app(app)
user = security_manager.find_user("gamma")
form_data = {}
job_metadata = {
@ -118,13 +159,25 @@ class TestAsyncQueries(SupersetTestCase):
load_chart_data_into_cache(job_metadata, form_data)
set_form_data.assert_called_once_with(form_data, "error", errors=errors)
@parameterized.expand(
[
("RedisCacheBackend", mock.Mock(spec=RedisCacheBackend)),
("RedisSentinelCacheBackend", mock.Mock(spec=RedisSentinelCacheBackend)),
("redis.Redis", mock.Mock(spec=redis.Redis)),
]
)
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@mock.patch.object(async_query_manager, "update_job")
def test_load_explore_json_into_cache(self, mock_update_job):
def test_load_explore_json_into_cache(
self, cache_type, cache_backend, mock_update_job
):
from superset.tasks.async_queries import load_explore_json_into_cache
app._got_first_request = False
async_query_manager.get_cache_backend = mock.Mock(return_value=cache_backend)
async_query_manager.init_app(app)
table = self.get_table(name="birth_names")
user = security_manager.find_user("gamma")
form_data = {
@ -146,19 +199,30 @@ class TestAsyncQueries(SupersetTestCase):
}
load_explore_json_into_cache(job_metadata, form_data)
mock_update_job.assert_called_once_with(
job_metadata, "done", result_url=mock.ANY
)
@parameterized.expand(
[
("RedisCacheBackend", mock.Mock(spec=RedisCacheBackend)),
("RedisSentinelCacheBackend", mock.Mock(spec=RedisSentinelCacheBackend)),
("redis.Redis", mock.Mock(spec=redis.Redis)),
]
)
@mock.patch.object(async_query_manager, "update_job")
@mock.patch("superset.tasks.async_queries.set_form_data")
def test_load_explore_json_into_cache_error(
self, mock_set_form_data, mock_update_job
self, cache_type, cache_backend, mock_set_form_data, mock_update_job
):
from superset.tasks.async_queries import load_explore_json_into_cache
app._got_first_request = False
async_query_manager.get_cache_backend = mock.Mock(return_value=cache_backend)
async_query_manager.init_app(app)
user = security_manager.find_user("gamma")
form_data = {}
job_metadata = {
@ -176,15 +240,25 @@ class TestAsyncQueries(SupersetTestCase):
errors = ["The dataset associated with this chart no longer exists"]
mock_update_job.assert_called_once_with(job_metadata, "error", errors=errors)
@parameterized.expand(
[
("RedisCacheBackend", mock.Mock(spec=RedisCacheBackend)),
("RedisSentinelCacheBackend", mock.Mock(spec=RedisSentinelCacheBackend)),
("redis.Redis", mock.Mock(spec=redis.Redis)),
]
)
@mock.patch.object(ChartDataCommand, "run")
@mock.patch.object(async_query_manager, "update_job")
def test_soft_timeout_load_explore_json_into_cache(
self, mock_update_job, mock_run_command
self, cache_type, cache_backend, mock_update_job, mock_run_command
):
from superset.tasks.async_queries import load_explore_json_into_cache
app._got_first_request = False
async_query_manager.get_cache_backend = mock.Mock(return_value=cache_backend)
async_query_manager.init_app(app)
user = security_manager.find_user("gamma")
form_data = {}
job_metadata = {
@ -194,7 +268,7 @@ class TestAsyncQueries(SupersetTestCase):
"status": "pending",
"errors": [],
}
errors = ["A timeout occurred while loading explore json, error"]
errors = ["A timeout occurred while loading explore JSON data"]
with pytest.raises(SoftTimeLimitExceeded):
with mock.patch(

View File

@ -17,15 +17,20 @@
from unittest import mock
from unittest.mock import ANY, Mock
import redis
from flask import g
from jwt import encode
from pytest import fixture, raises
from pytest import fixture, mark, raises
from superset import security_manager
from superset.async_events.async_query_manager import (
AsyncQueryManager,
AsyncQueryTokenException,
)
from superset.async_events.cache_backend import (
RedisCacheBackend,
RedisSentinelCacheBackend,
)
JWT_TOKEN_SECRET = "some_secret"
JWT_TOKEN_COOKIE_NAME = "superset_async_jwt"
@ -36,7 +41,6 @@ def async_query_manager():
query_manager = AsyncQueryManager()
query_manager._jwt_secret = JWT_TOKEN_SECRET
query_manager._jwt_cookie_name = JWT_TOKEN_COOKIE_NAME
return query_manager
@ -75,12 +79,24 @@ def test_parse_channel_id_from_request_bad_jwt(async_query_manager):
async_query_manager.parse_channel_id_from_request(request)
@mark.parametrize(
"cache_type, cache_backend",
[
("RedisCacheBackend", mock.Mock(spec=RedisCacheBackend)),
("RedisSentinelCacheBackend", mock.Mock(spec=RedisSentinelCacheBackend)),
("redis.Redis", mock.Mock(spec=redis.Redis)),
],
)
@mock.patch("superset.is_feature_enabled")
def test_submit_chart_data_job_as_guest_user(
is_feature_enabled_mock, async_query_manager
is_feature_enabled_mock, async_query_manager, cache_type, cache_backend
):
is_feature_enabled_mock.return_value = True
set_current_as_guest_user()
# Mock the get_cache_backend method to return the current cache backend
async_query_manager.get_cache_backend = mock.Mock(return_value=cache_backend)
job_mock = Mock()
async_query_manager._load_chart_data_into_cache_job = job_mock
job_meta = async_query_manager.submit_chart_data_job(
@ -105,14 +121,27 @@ def test_submit_chart_data_job_as_guest_user(
)
assert "guest_token" not in job_meta
job_mock.reset_mock() # Reset the mock for the next iteration
@mark.parametrize(
"cache_type, cache_backend",
[
("RedisCacheBackend", mock.Mock(spec=RedisCacheBackend)),
("RedisSentinelCacheBackend", mock.Mock(spec=RedisSentinelCacheBackend)),
("redis.Redis", mock.Mock(spec=redis.Redis)),
],
)
@mock.patch("superset.is_feature_enabled")
def test_submit_explore_json_job_as_guest_user(
is_feature_enabled_mock, async_query_manager
is_feature_enabled_mock, async_query_manager, cache_type, cache_backend
):
is_feature_enabled_mock.return_value = True
set_current_as_guest_user()
# Mock the get_cache_backend method to return the current cache backend
async_query_manager.get_cache_backend = mock.Mock(return_value=cache_backend)
job_mock = Mock()
async_query_manager._load_explore_json_into_cache_job = job_mock
job_meta = async_query_manager.submit_explore_json_job(