feat(GAQ): Add Redis Sentinel Support for Global Async Queries (#29912)
Co-authored-by: Sivarajan Narayanan <narayanan_sivarajan@apple.com>
This commit is contained in:
parent
cd6b8b2f6d
commit
103cd3d6f3
|
|
@ -14,20 +14,31 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any, Literal, Optional
|
||||
from typing import Any, Literal, Optional, Union
|
||||
|
||||
import jwt
|
||||
import redis
|
||||
from flask import Flask, Request, request, Response, session
|
||||
from flask_caching.backends.base import BaseCache
|
||||
|
||||
from superset.async_events.cache_backend import (
|
||||
RedisCacheBackend,
|
||||
RedisSentinelCacheBackend,
|
||||
)
|
||||
from superset.utils import json
|
||||
from superset.utils.core import get_user_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CacheBackendNotInitialized(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class AsyncQueryTokenException(Exception):
|
||||
pass
|
||||
|
||||
|
|
@ -55,13 +66,32 @@ def parse_event(event_data: tuple[str, dict[str, Any]]) -> dict[str, Any]:
|
|||
return {"id": event_id, **json.loads(event_payload)}
|
||||
|
||||
|
||||
def increment_id(redis_id: str) -> str:
|
||||
def increment_id(entry_id: str) -> str:
|
||||
# redis stream IDs are in this format: '1607477697866-0'
|
||||
try:
|
||||
prefix, last = redis_id[:-1], int(redis_id[-1])
|
||||
prefix, last = entry_id[:-1], int(entry_id[-1])
|
||||
return prefix + str(last + 1)
|
||||
except Exception: # pylint: disable=broad-except
|
||||
return redis_id
|
||||
return entry_id
|
||||
|
||||
|
||||
def get_cache_backend(
|
||||
config: dict[str, Any],
|
||||
) -> Union[RedisCacheBackend, RedisSentinelCacheBackend, redis.Redis]: # type: ignore
|
||||
cache_config = config.get("GLOBAL_ASYNC_QUERIES_CACHE_BACKEND", {})
|
||||
cache_type = cache_config.get("CACHE_TYPE")
|
||||
|
||||
if cache_type == "RedisCache":
|
||||
return RedisCacheBackend.from_config(cache_config)
|
||||
|
||||
if cache_type == "RedisSentinelCache":
|
||||
return RedisSentinelCacheBackend.from_config(cache_config)
|
||||
|
||||
# TODO: Deprecate hardcoded plain Redis code and expand cache backend options.
|
||||
# Maintain backward compatibility with 'GLOBAL_ASYNC_QUERIES_REDIS_CONFIG' until it is deprecated.
|
||||
return redis.Redis(
|
||||
**config["GLOBAL_ASYNC_QUERIES_REDIS_CONFIG"], decode_responses=True
|
||||
)
|
||||
|
||||
|
||||
class AsyncQueryManager:
|
||||
|
|
@ -73,7 +103,7 @@ class AsyncQueryManager:
|
|||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self._redis: redis.Redis # type: ignore
|
||||
self._cache: Optional[BaseCache] = None
|
||||
self._stream_prefix: str = ""
|
||||
self._stream_limit: Optional[int]
|
||||
self._stream_limit_firehose: Optional[int]
|
||||
|
|
@ -88,10 +118,9 @@ class AsyncQueryManager:
|
|||
|
||||
def init_app(self, app: Flask) -> None:
|
||||
config = app.config
|
||||
if (
|
||||
config["CACHE_CONFIG"]["CACHE_TYPE"] == "null"
|
||||
or config["DATA_CACHE_CONFIG"]["CACHE_TYPE"] == "null"
|
||||
):
|
||||
cache_type = config.get("CACHE_CONFIG", {}).get("CACHE_TYPE")
|
||||
data_cache_type = config.get("DATA_CACHE_CONFIG", {}).get("CACHE_TYPE")
|
||||
if cache_type in [None, "null"] or data_cache_type in [None, "null"]:
|
||||
raise Exception( # pylint: disable=broad-exception-raised
|
||||
"""
|
||||
Cache backends (CACHE_CONFIG, DATA_CACHE_CONFIG) must be configured
|
||||
|
|
@ -99,14 +128,14 @@ class AsyncQueryManager:
|
|||
"""
|
||||
)
|
||||
|
||||
self._cache = get_cache_backend(config)
|
||||
logger.debug("Using GAQ Cache backend as %s", type(self._cache).__name__)
|
||||
|
||||
if len(config["GLOBAL_ASYNC_QUERIES_JWT_SECRET"]) < 32:
|
||||
raise AsyncQueryTokenException(
|
||||
"Please provide a JWT secret at least 32 bytes long"
|
||||
)
|
||||
|
||||
self._redis = redis.Redis(
|
||||
**config["GLOBAL_ASYNC_QUERIES_REDIS_CONFIG"], decode_responses=True
|
||||
)
|
||||
self._stream_prefix = config["GLOBAL_ASYNC_QUERIES_REDIS_STREAM_PREFIX"]
|
||||
self._stream_limit = config["GLOBAL_ASYNC_QUERIES_REDIS_STREAM_LIMIT"]
|
||||
self._stream_limit_firehose = config[
|
||||
|
|
@ -230,14 +259,35 @@ class AsyncQueryManager:
|
|||
def read_events(
|
||||
self, channel: str, last_id: Optional[str]
|
||||
) -> list[Optional[dict[str, Any]]]:
|
||||
if not self._cache:
|
||||
raise CacheBackendNotInitialized("Cache backend not initialized")
|
||||
|
||||
stream_name = f"{self._stream_prefix}{channel}"
|
||||
start_id = increment_id(last_id) if last_id else "-"
|
||||
results = self._redis.xrange(stream_name, start_id, "+", self.MAX_EVENT_COUNT)
|
||||
results = self._cache.xrange(stream_name, start_id, "+", self.MAX_EVENT_COUNT)
|
||||
# Decode bytes to strings, decode_responses is not supported at RedisCache and RedisSentinelCache
|
||||
if isinstance(self._cache, (RedisSentinelCacheBackend, RedisCacheBackend)):
|
||||
decoded_results = [
|
||||
(
|
||||
event_id.decode("utf-8"),
|
||||
{
|
||||
key.decode("utf-8"): value.decode("utf-8")
|
||||
for key, value in event_data.items()
|
||||
},
|
||||
)
|
||||
for event_id, event_data in results
|
||||
]
|
||||
return (
|
||||
[] if not decoded_results else list(map(parse_event, decoded_results))
|
||||
)
|
||||
return [] if not results else list(map(parse_event, results))
|
||||
|
||||
def update_job(
|
||||
self, job_metadata: dict[str, Any], status: str, **kwargs: Any
|
||||
) -> None:
|
||||
if not self._cache:
|
||||
raise CacheBackendNotInitialized("Cache backend not initialized")
|
||||
|
||||
if "channel_id" not in job_metadata:
|
||||
raise AsyncQueryJobException("No channel ID specified")
|
||||
|
||||
|
|
@ -253,5 +303,5 @@ class AsyncQueryManager:
|
|||
logger.debug("********** logging event data to stream %s", scoped_stream_name)
|
||||
logger.debug(event_data)
|
||||
|
||||
self._redis.xadd(scoped_stream_name, event_data, "*", self._stream_limit)
|
||||
self._redis.xadd(full_stream_name, event_data, "*", self._stream_limit_firehose)
|
||||
self._cache.xadd(scoped_stream_name, event_data, "*", self._stream_limit)
|
||||
self._cache.xadd(full_stream_name, event_data, "*", self._stream_limit_firehose)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,209 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import redis
|
||||
from flask_caching.backends.rediscache import RedisCache, RedisSentinelCache
|
||||
from redis.sentinel import Sentinel
|
||||
|
||||
|
||||
class RedisCacheBackend(RedisCache):
|
||||
MAX_EVENT_COUNT = 100
|
||||
|
||||
def __init__( # pylint: disable=too-many-arguments
|
||||
self,
|
||||
host: str,
|
||||
port: int,
|
||||
password: Optional[str] = None,
|
||||
db: int = 0,
|
||||
default_timeout: int = 300,
|
||||
key_prefix: Optional[str] = None,
|
||||
ssl: bool = False,
|
||||
ssl_certfile: Optional[str] = None,
|
||||
ssl_keyfile: Optional[str] = None,
|
||||
ssl_cert_reqs: str = "required",
|
||||
ssl_ca_certs: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
host=host,
|
||||
port=port,
|
||||
password=password,
|
||||
db=db,
|
||||
default_timeout=default_timeout,
|
||||
key_prefix=key_prefix,
|
||||
**kwargs,
|
||||
)
|
||||
self._cache = redis.Redis(
|
||||
host=host,
|
||||
port=port,
|
||||
password=password,
|
||||
db=db,
|
||||
ssl=ssl,
|
||||
ssl_certfile=ssl_certfile,
|
||||
ssl_keyfile=ssl_keyfile,
|
||||
ssl_cert_reqs=ssl_cert_reqs,
|
||||
ssl_ca_certs=ssl_ca_certs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def xadd(
|
||||
self,
|
||||
stream_name: str,
|
||||
event_data: Dict[str, Any],
|
||||
event_id: str = "*",
|
||||
maxlen: Optional[int] = None,
|
||||
) -> str:
|
||||
return self._cache.xadd(stream_name, event_data, event_id, maxlen)
|
||||
|
||||
def xrange(
|
||||
self,
|
||||
stream_name: str,
|
||||
start: str = "-",
|
||||
end: str = "+",
|
||||
count: Optional[int] = None,
|
||||
) -> List[Any]:
|
||||
count = count or self.MAX_EVENT_COUNT
|
||||
return self._cache.xrange(stream_name, start, end, count)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "RedisCacheBackend":
|
||||
kwargs = {
|
||||
"host": config.get("CACHE_REDIS_HOST", "localhost"),
|
||||
"port": config.get("CACHE_REDIS_PORT", 6379),
|
||||
"db": config.get("CACHE_REDIS_DB", 0),
|
||||
"password": config.get("CACHE_REDIS_PASSWORD", None),
|
||||
"key_prefix": config.get("CACHE_KEY_PREFIX", None),
|
||||
"default_timeout": config.get("CACHE_DEFAULT_TIMEOUT", 300),
|
||||
"ssl": config.get("CACHE_REDIS_SSL", False),
|
||||
"ssl_certfile": config.get("CACHE_REDIS_SSL_CERTFILE", None),
|
||||
"ssl_keyfile": config.get("CACHE_REDIS_SSL_KEYFILE", None),
|
||||
"ssl_cert_reqs": config.get("CACHE_REDIS_SSL_CERT_REQS", "required"),
|
||||
"ssl_ca_certs": config.get("CACHE_REDIS_SSL_CA_CERTS", None),
|
||||
}
|
||||
return cls(**kwargs)
|
||||
|
||||
|
||||
class RedisSentinelCacheBackend(RedisSentinelCache):
|
||||
MAX_EVENT_COUNT = 100
|
||||
|
||||
def __init__( # pylint: disable=too-many-arguments
|
||||
self,
|
||||
sentinels: List[Tuple[str, int]],
|
||||
master: str,
|
||||
password: Optional[str] = None,
|
||||
sentinel_password: Optional[str] = None,
|
||||
db: int = 0,
|
||||
default_timeout: int = 300,
|
||||
key_prefix: str = "",
|
||||
ssl: bool = False,
|
||||
ssl_certfile: Optional[str] = None,
|
||||
ssl_keyfile: Optional[str] = None,
|
||||
ssl_cert_reqs: str = "required",
|
||||
ssl_ca_certs: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
# Sentinel dont directly support SSL
|
||||
# Initialize Sentinel without SSL parameters
|
||||
self._sentinel = Sentinel(
|
||||
sentinels,
|
||||
sentinel_kwargs={
|
||||
"password": sentinel_password,
|
||||
},
|
||||
**{
|
||||
k: v
|
||||
for k, v in kwargs.items()
|
||||
if k
|
||||
not in [
|
||||
"ssl",
|
||||
"ssl_certfile",
|
||||
"ssl_keyfile",
|
||||
"ssl_cert_reqs",
|
||||
"ssl_ca_certs",
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
# Prepare SSL-related arguments for master_for method
|
||||
master_kwargs = {
|
||||
"password": password,
|
||||
"ssl": ssl,
|
||||
"ssl_certfile": ssl_certfile if ssl else None,
|
||||
"ssl_keyfile": ssl_keyfile if ssl else None,
|
||||
"ssl_cert_reqs": ssl_cert_reqs if ssl else None,
|
||||
"ssl_ca_certs": ssl_ca_certs if ssl else None,
|
||||
}
|
||||
|
||||
# If SSL is False, remove all SSL-related keys
|
||||
# SSL_* are expected only if SSL is True
|
||||
if not ssl:
|
||||
master_kwargs = {
|
||||
k: v for k, v in master_kwargs.items() if not k.startswith("ssl")
|
||||
}
|
||||
|
||||
# Filter out None values from master_kwargs
|
||||
master_kwargs = {k: v for k, v in master_kwargs.items() if v is not None}
|
||||
|
||||
# Initialize Redis master connection
|
||||
self._cache = self._sentinel.master_for(master, **master_kwargs)
|
||||
|
||||
# Call the parent class constructor
|
||||
super().__init__(
|
||||
host=None,
|
||||
port=None,
|
||||
password=password,
|
||||
db=db,
|
||||
default_timeout=default_timeout,
|
||||
key_prefix=key_prefix,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def xadd(
|
||||
self,
|
||||
stream_name: str,
|
||||
event_data: Dict[str, Any],
|
||||
event_id: str = "*",
|
||||
maxlen: Optional[int] = None,
|
||||
) -> str:
|
||||
return self._cache.xadd(stream_name, event_data, event_id, maxlen)
|
||||
|
||||
def xrange(
|
||||
self,
|
||||
stream_name: str,
|
||||
start: str = "-",
|
||||
end: str = "+",
|
||||
count: Optional[int] = None,
|
||||
) -> List[Any]:
|
||||
count = count or self.MAX_EVENT_COUNT
|
||||
return self._cache.xrange(stream_name, start, end, count)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "RedisSentinelCacheBackend":
|
||||
kwargs = {
|
||||
"sentinels": config.get("CACHE_REDIS_SENTINELS", [("127.0.0.1", 26379)]),
|
||||
"master": config.get("CACHE_REDIS_SENTINEL_MASTER", "mymaster"),
|
||||
"password": config.get("CACHE_REDIS_PASSWORD", None),
|
||||
"sentinel_password": config.get("CACHE_REDIS_SENTINEL_PASSWORD", None),
|
||||
"key_prefix": config.get("CACHE_KEY_PREFIX", ""),
|
||||
"db": config.get("CACHE_REDIS_DB", 0),
|
||||
"ssl": config.get("CACHE_REDIS_SSL", False),
|
||||
"ssl_certfile": config.get("CACHE_REDIS_SSL_CERTFILE", None),
|
||||
"ssl_keyfile": config.get("CACHE_REDIS_SSL_KEYFILE", None),
|
||||
"ssl_cert_reqs": config.get("CACHE_REDIS_SSL_CERT_REQS", "required"),
|
||||
"ssl_ca_certs": config.get("CACHE_REDIS_SSL_CA_CERTS", None),
|
||||
}
|
||||
return cls(**kwargs)
|
||||
|
|
@ -1690,6 +1690,28 @@ GLOBAL_ASYNC_QUERIES_POLLING_DELAY = int(
|
|||
)
|
||||
GLOBAL_ASYNC_QUERIES_WEBSOCKET_URL = "ws://127.0.0.1:8080/"
|
||||
|
||||
# Global async queries cache backend configuration options:
|
||||
# - Set 'CACHE_TYPE' to 'RedisCache' for RedisCacheBackend.
|
||||
# - Set 'CACHE_TYPE' to 'RedisSentinelCache' for RedisSentinelCacheBackend.
|
||||
# - Set 'CACHE_TYPE' to 'None' to fall back on 'GLOBAL_ASYNC_QUERIES_REDIS_CONFIG'.
|
||||
GLOBAL_ASYNC_QUERIES_CACHE_BACKEND = {
|
||||
"CACHE_TYPE": "RedisCache",
|
||||
"CACHE_REDIS_HOST": "localhost",
|
||||
"CACHE_REDIS_PORT": 6379,
|
||||
"CACHE_REDIS_USER": "",
|
||||
"CACHE_REDIS_PASSWORD": "",
|
||||
"CACHE_REDIS_DB": 0,
|
||||
"CACHE_DEFAULT_TIMEOUT": 300,
|
||||
"CACHE_REDIS_SENTINELS": [("localhost", 26379)],
|
||||
"CACHE_REDIS_SENTINEL_MASTER": "mymaster",
|
||||
"CACHE_REDIS_SENTINEL_PASSWORD": None,
|
||||
"CACHE_REDIS_SSL": False, # True or False
|
||||
"CACHE_REDIS_SSL_CERTFILE": None,
|
||||
"CACHE_REDIS_SSL_KEYFILE": None,
|
||||
"CACHE_REDIS_SSL_CERT_REQS": "required",
|
||||
"CACHE_REDIS_SSL_CA_CERTS": None,
|
||||
}
|
||||
|
||||
# Embedded config options
|
||||
GUEST_ROLE_NAME = "Public"
|
||||
GUEST_TOKEN_JWT_SECRET = "test-guest-secret-change-me"
|
||||
|
|
|
|||
|
|
@ -14,9 +14,15 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from typing import Optional
|
||||
from typing import Any, Optional, Type
|
||||
from unittest import mock
|
||||
|
||||
import redis
|
||||
|
||||
from superset.async_events.cache_backend import (
|
||||
RedisCacheBackend,
|
||||
RedisSentinelCacheBackend,
|
||||
)
|
||||
from superset.extensions import async_query_manager
|
||||
from superset.utils import json
|
||||
from tests.integration_tests.base_tests import SupersetTestCase
|
||||
|
|
@ -32,12 +38,21 @@ class TestAsyncEventApi(SupersetTestCase):
|
|||
uri = f"{base_uri}?last_id={last_id}" if last_id else base_uri
|
||||
return self.client.get(uri)
|
||||
|
||||
@mock.patch("uuid.uuid4", return_value=UUID)
|
||||
def test_events(self, mock_uuid4):
|
||||
def run_test_with_cache_backend(self, cache_backend_cls: Type[Any], test_func):
|
||||
app._got_first_request = False
|
||||
async_query_manager.init_app(app)
|
||||
|
||||
# Create a mock cache backend instance
|
||||
mock_cache = mock.Mock(spec=cache_backend_cls)
|
||||
|
||||
# Set the mock cache instance
|
||||
async_query_manager._cache = mock_cache
|
||||
|
||||
self.login(ADMIN_USERNAME)
|
||||
with mock.patch.object(async_query_manager._redis, "xrange") as mock_xrange:
|
||||
test_func(mock_cache)
|
||||
|
||||
def _test_events_logic(self, mock_cache):
|
||||
with mock.patch.object(mock_cache, "xrange") as mock_xrange:
|
||||
rv = self.fetch_events()
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
|
||||
|
|
@ -46,12 +61,8 @@ class TestAsyncEventApi(SupersetTestCase):
|
|||
mock_xrange.assert_called_with(channel_id, "-", "+", 100)
|
||||
self.assertEqual(response, {"result": []})
|
||||
|
||||
@mock.patch("uuid.uuid4", return_value=UUID)
|
||||
def test_events_last_id(self, mock_uuid4):
|
||||
app._got_first_request = False
|
||||
async_query_manager.init_app(app)
|
||||
self.login(ADMIN_USERNAME)
|
||||
with mock.patch.object(async_query_manager._redis, "xrange") as mock_xrange:
|
||||
def _test_events_last_id_logic(self, mock_cache):
|
||||
with mock.patch.object(mock_cache, "xrange") as mock_xrange:
|
||||
rv = self.fetch_events("1607471525180-0")
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
|
||||
|
|
@ -60,12 +71,8 @@ class TestAsyncEventApi(SupersetTestCase):
|
|||
mock_xrange.assert_called_with(channel_id, "1607471525180-1", "+", 100)
|
||||
self.assertEqual(response, {"result": []})
|
||||
|
||||
@mock.patch("uuid.uuid4", return_value=UUID)
|
||||
def test_events_results(self, mock_uuid4):
|
||||
app._got_first_request = False
|
||||
async_query_manager.init_app(app)
|
||||
self.login(ADMIN_USERNAME)
|
||||
with mock.patch.object(async_query_manager._redis, "xrange") as mock_xrange:
|
||||
def _test_events_results_logic(self, mock_cache):
|
||||
with mock.patch.object(mock_cache, "xrange") as mock_xrange:
|
||||
mock_xrange.return_value = [
|
||||
(
|
||||
"1607477697866-0",
|
||||
|
|
@ -110,6 +117,20 @@ class TestAsyncEventApi(SupersetTestCase):
|
|||
}
|
||||
self.assertEqual(response, expected)
|
||||
|
||||
@mock.patch("uuid.uuid4", return_value=UUID)
|
||||
def test_events_redis_cache_backend(self, mock_uuid4):
|
||||
self.run_test_with_cache_backend(RedisCacheBackend, self._test_events_logic)
|
||||
|
||||
@mock.patch("uuid.uuid4", return_value=UUID)
|
||||
def test_events_redis_sentinel_cache_backend(self, mock_uuid4):
|
||||
self.run_test_with_cache_backend(
|
||||
RedisSentinelCacheBackend, self._test_events_logic
|
||||
)
|
||||
|
||||
@mock.patch("uuid.uuid4", return_value=UUID)
|
||||
def test_events_redis(self, mock_uuid4):
|
||||
self.run_test_with_cache_backend(redis.Redis, self._test_events_logic)
|
||||
|
||||
def test_events_no_login(self):
|
||||
app._got_first_request = False
|
||||
async_query_manager.init_app(app)
|
||||
|
|
|
|||
|
|
@ -20,8 +20,14 @@ from unittest import mock
|
|||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
import redis
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from parameterized import parameterized
|
||||
|
||||
from superset.async_events.cache_backend import (
|
||||
RedisCacheBackend,
|
||||
RedisSentinelCacheBackend,
|
||||
)
|
||||
from superset.commands.chart.data.get_data_command import ChartDataCommand
|
||||
from superset.commands.chart.exceptions import ChartDataQueryFailedError
|
||||
from superset.exceptions import SupersetException
|
||||
|
|
@ -38,17 +44,29 @@ from tests.integration_tests.fixtures.tags import (
|
|||
from tests.integration_tests.test_app import app
|
||||
|
||||
|
||||
@pytest.mark.usefixtures(
|
||||
"load_birth_names_data", "load_birth_names_dashboard_with_slices"
|
||||
)
|
||||
class TestAsyncQueries(SupersetTestCase):
|
||||
@pytest.mark.usefixtures(
|
||||
"load_birth_names_data", "load_birth_names_dashboard_with_slices"
|
||||
@parameterized.expand(
|
||||
[
|
||||
("RedisCacheBackend", mock.Mock(spec=RedisCacheBackend)),
|
||||
("RedisSentinelCacheBackend", mock.Mock(spec=RedisSentinelCacheBackend)),
|
||||
("redis.Redis", mock.Mock(spec=redis.Redis)),
|
||||
]
|
||||
)
|
||||
@mock.patch.object(async_query_manager, "update_job")
|
||||
@mock.patch("superset.tasks.async_queries.set_form_data")
|
||||
def test_load_chart_data_into_cache(self, mock_set_form_data, mock_update_job):
|
||||
@mock.patch.object(async_query_manager, "update_job")
|
||||
def test_load_chart_data_into_cache(
|
||||
self, cache_type, cache_backend, mock_update_job, mock_set_form_data
|
||||
):
|
||||
from superset.tasks.async_queries import load_chart_data_into_cache
|
||||
|
||||
app._got_first_request = False
|
||||
|
||||
async_query_manager.get_cache_backend = mock.Mock(return_value=cache_backend)
|
||||
async_query_manager.init_app(app)
|
||||
|
||||
query_context = get_query_context("birth_names")
|
||||
user = security_manager.find_user("gamma")
|
||||
job_metadata = {
|
||||
|
|
@ -60,20 +78,33 @@ class TestAsyncQueries(SupersetTestCase):
|
|||
}
|
||||
|
||||
load_chart_data_into_cache(job_metadata, query_context)
|
||||
|
||||
mock_set_form_data.assert_called_once_with(query_context)
|
||||
mock_update_job.assert_called_once_with(
|
||||
job_metadata, "done", result_url=mock.ANY
|
||||
)
|
||||
|
||||
@parameterized.expand(
|
||||
[
|
||||
("RedisCacheBackend", mock.Mock(spec=RedisCacheBackend)),
|
||||
("RedisSentinelCacheBackend", mock.Mock(spec=RedisSentinelCacheBackend)),
|
||||
("redis.Redis", mock.Mock(spec=redis.Redis)),
|
||||
]
|
||||
)
|
||||
@mock.patch.object(
|
||||
ChartDataCommand, "run", side_effect=ChartDataQueryFailedError("Error: foo")
|
||||
)
|
||||
@mock.patch.object(async_query_manager, "update_job")
|
||||
def test_load_chart_data_into_cache_error(self, mock_update_job, mock_run_command):
|
||||
def test_load_chart_data_into_cache_error(
|
||||
self, cache_type, cache_backend, mock_update_job, mock_run_command
|
||||
):
|
||||
from superset.tasks.async_queries import load_chart_data_into_cache
|
||||
|
||||
app._got_first_request = False
|
||||
|
||||
async_query_manager.get_cache_backend = mock.Mock(return_value=cache_backend)
|
||||
async_query_manager.init_app(app)
|
||||
|
||||
query_context = get_query_context("birth_names")
|
||||
user = security_manager.find_user("gamma")
|
||||
job_metadata = {
|
||||
|
|
@ -90,15 +121,25 @@ class TestAsyncQueries(SupersetTestCase):
|
|||
errors = [{"message": "Error: foo"}]
|
||||
mock_update_job.assert_called_once_with(job_metadata, "error", errors=errors)
|
||||
|
||||
@parameterized.expand(
|
||||
[
|
||||
("RedisCacheBackend", mock.Mock(spec=RedisCacheBackend)),
|
||||
("RedisSentinelCacheBackend", mock.Mock(spec=RedisSentinelCacheBackend)),
|
||||
("redis.Redis", mock.Mock(spec=redis.Redis)),
|
||||
]
|
||||
)
|
||||
@mock.patch.object(ChartDataCommand, "run")
|
||||
@mock.patch.object(async_query_manager, "update_job")
|
||||
def test_soft_timeout_load_chart_data_into_cache(
|
||||
self, mock_update_job, mock_run_command
|
||||
self, cache_type, cache_backend, mock_update_job, mock_run_command
|
||||
):
|
||||
from superset.tasks.async_queries import load_chart_data_into_cache
|
||||
|
||||
app._got_first_request = False
|
||||
|
||||
async_query_manager.get_cache_backend = mock.Mock(return_value=cache_backend)
|
||||
async_query_manager.init_app(app)
|
||||
|
||||
user = security_manager.find_user("gamma")
|
||||
form_data = {}
|
||||
job_metadata = {
|
||||
|
|
@ -118,13 +159,25 @@ class TestAsyncQueries(SupersetTestCase):
|
|||
load_chart_data_into_cache(job_metadata, form_data)
|
||||
set_form_data.assert_called_once_with(form_data, "error", errors=errors)
|
||||
|
||||
@parameterized.expand(
|
||||
[
|
||||
("RedisCacheBackend", mock.Mock(spec=RedisCacheBackend)),
|
||||
("RedisSentinelCacheBackend", mock.Mock(spec=RedisSentinelCacheBackend)),
|
||||
("redis.Redis", mock.Mock(spec=redis.Redis)),
|
||||
]
|
||||
)
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
@mock.patch.object(async_query_manager, "update_job")
|
||||
def test_load_explore_json_into_cache(self, mock_update_job):
|
||||
def test_load_explore_json_into_cache(
|
||||
self, cache_type, cache_backend, mock_update_job
|
||||
):
|
||||
from superset.tasks.async_queries import load_explore_json_into_cache
|
||||
|
||||
app._got_first_request = False
|
||||
|
||||
async_query_manager.get_cache_backend = mock.Mock(return_value=cache_backend)
|
||||
async_query_manager.init_app(app)
|
||||
|
||||
table = self.get_table(name="birth_names")
|
||||
user = security_manager.find_user("gamma")
|
||||
form_data = {
|
||||
|
|
@ -146,19 +199,30 @@ class TestAsyncQueries(SupersetTestCase):
|
|||
}
|
||||
|
||||
load_explore_json_into_cache(job_metadata, form_data)
|
||||
|
||||
mock_update_job.assert_called_once_with(
|
||||
job_metadata, "done", result_url=mock.ANY
|
||||
)
|
||||
|
||||
@parameterized.expand(
|
||||
[
|
||||
("RedisCacheBackend", mock.Mock(spec=RedisCacheBackend)),
|
||||
("RedisSentinelCacheBackend", mock.Mock(spec=RedisSentinelCacheBackend)),
|
||||
("redis.Redis", mock.Mock(spec=redis.Redis)),
|
||||
]
|
||||
)
|
||||
@mock.patch.object(async_query_manager, "update_job")
|
||||
@mock.patch("superset.tasks.async_queries.set_form_data")
|
||||
def test_load_explore_json_into_cache_error(
|
||||
self, mock_set_form_data, mock_update_job
|
||||
self, cache_type, cache_backend, mock_set_form_data, mock_update_job
|
||||
):
|
||||
from superset.tasks.async_queries import load_explore_json_into_cache
|
||||
|
||||
app._got_first_request = False
|
||||
|
||||
async_query_manager.get_cache_backend = mock.Mock(return_value=cache_backend)
|
||||
async_query_manager.init_app(app)
|
||||
|
||||
user = security_manager.find_user("gamma")
|
||||
form_data = {}
|
||||
job_metadata = {
|
||||
|
|
@ -176,15 +240,25 @@ class TestAsyncQueries(SupersetTestCase):
|
|||
errors = ["The dataset associated with this chart no longer exists"]
|
||||
mock_update_job.assert_called_once_with(job_metadata, "error", errors=errors)
|
||||
|
||||
@parameterized.expand(
|
||||
[
|
||||
("RedisCacheBackend", mock.Mock(spec=RedisCacheBackend)),
|
||||
("RedisSentinelCacheBackend", mock.Mock(spec=RedisSentinelCacheBackend)),
|
||||
("redis.Redis", mock.Mock(spec=redis.Redis)),
|
||||
]
|
||||
)
|
||||
@mock.patch.object(ChartDataCommand, "run")
|
||||
@mock.patch.object(async_query_manager, "update_job")
|
||||
def test_soft_timeout_load_explore_json_into_cache(
|
||||
self, mock_update_job, mock_run_command
|
||||
self, cache_type, cache_backend, mock_update_job, mock_run_command
|
||||
):
|
||||
from superset.tasks.async_queries import load_explore_json_into_cache
|
||||
|
||||
app._got_first_request = False
|
||||
|
||||
async_query_manager.get_cache_backend = mock.Mock(return_value=cache_backend)
|
||||
async_query_manager.init_app(app)
|
||||
|
||||
user = security_manager.find_user("gamma")
|
||||
form_data = {}
|
||||
job_metadata = {
|
||||
|
|
@ -194,7 +268,7 @@ class TestAsyncQueries(SupersetTestCase):
|
|||
"status": "pending",
|
||||
"errors": [],
|
||||
}
|
||||
errors = ["A timeout occurred while loading explore json, error"]
|
||||
errors = ["A timeout occurred while loading explore JSON data"]
|
||||
|
||||
with pytest.raises(SoftTimeLimitExceeded):
|
||||
with mock.patch(
|
||||
|
|
|
|||
|
|
@ -17,15 +17,20 @@
|
|||
from unittest import mock
|
||||
from unittest.mock import ANY, Mock
|
||||
|
||||
import redis
|
||||
from flask import g
|
||||
from jwt import encode
|
||||
from pytest import fixture, raises
|
||||
from pytest import fixture, mark, raises
|
||||
|
||||
from superset import security_manager
|
||||
from superset.async_events.async_query_manager import (
|
||||
AsyncQueryManager,
|
||||
AsyncQueryTokenException,
|
||||
)
|
||||
from superset.async_events.cache_backend import (
|
||||
RedisCacheBackend,
|
||||
RedisSentinelCacheBackend,
|
||||
)
|
||||
|
||||
JWT_TOKEN_SECRET = "some_secret"
|
||||
JWT_TOKEN_COOKIE_NAME = "superset_async_jwt"
|
||||
|
|
@ -36,7 +41,6 @@ def async_query_manager():
|
|||
query_manager = AsyncQueryManager()
|
||||
query_manager._jwt_secret = JWT_TOKEN_SECRET
|
||||
query_manager._jwt_cookie_name = JWT_TOKEN_COOKIE_NAME
|
||||
|
||||
return query_manager
|
||||
|
||||
|
||||
|
|
@ -75,12 +79,24 @@ def test_parse_channel_id_from_request_bad_jwt(async_query_manager):
|
|||
async_query_manager.parse_channel_id_from_request(request)
|
||||
|
||||
|
||||
@mark.parametrize(
|
||||
"cache_type, cache_backend",
|
||||
[
|
||||
("RedisCacheBackend", mock.Mock(spec=RedisCacheBackend)),
|
||||
("RedisSentinelCacheBackend", mock.Mock(spec=RedisSentinelCacheBackend)),
|
||||
("redis.Redis", mock.Mock(spec=redis.Redis)),
|
||||
],
|
||||
)
|
||||
@mock.patch("superset.is_feature_enabled")
|
||||
def test_submit_chart_data_job_as_guest_user(
|
||||
is_feature_enabled_mock, async_query_manager
|
||||
is_feature_enabled_mock, async_query_manager, cache_type, cache_backend
|
||||
):
|
||||
is_feature_enabled_mock.return_value = True
|
||||
set_current_as_guest_user()
|
||||
|
||||
# Mock the get_cache_backend method to return the current cache backend
|
||||
async_query_manager.get_cache_backend = mock.Mock(return_value=cache_backend)
|
||||
|
||||
job_mock = Mock()
|
||||
async_query_manager._load_chart_data_into_cache_job = job_mock
|
||||
job_meta = async_query_manager.submit_chart_data_job(
|
||||
|
|
@ -105,14 +121,27 @@ def test_submit_chart_data_job_as_guest_user(
|
|||
)
|
||||
|
||||
assert "guest_token" not in job_meta
|
||||
job_mock.reset_mock() # Reset the mock for the next iteration
|
||||
|
||||
|
||||
@mark.parametrize(
|
||||
"cache_type, cache_backend",
|
||||
[
|
||||
("RedisCacheBackend", mock.Mock(spec=RedisCacheBackend)),
|
||||
("RedisSentinelCacheBackend", mock.Mock(spec=RedisSentinelCacheBackend)),
|
||||
("redis.Redis", mock.Mock(spec=redis.Redis)),
|
||||
],
|
||||
)
|
||||
@mock.patch("superset.is_feature_enabled")
|
||||
def test_submit_explore_json_job_as_guest_user(
|
||||
is_feature_enabled_mock, async_query_manager
|
||||
is_feature_enabled_mock, async_query_manager, cache_type, cache_backend
|
||||
):
|
||||
is_feature_enabled_mock.return_value = True
|
||||
set_current_as_guest_user()
|
||||
|
||||
# Mock the get_cache_backend method to return the current cache backend
|
||||
async_query_manager.get_cache_backend = mock.Mock(return_value=cache_backend)
|
||||
|
||||
job_mock = Mock()
|
||||
async_query_manager._load_explore_json_into_cache_job = job_mock
|
||||
job_meta = async_query_manager.submit_explore_json_job(
|
||||
|
|
|
|||
Loading…
Reference in New Issue