chore(async): Initial Refactoring of Global Async Queries (#25466)
This commit is contained in:
parent
36ed617090
commit
db7f5fed31
|
|
@ -38,6 +38,7 @@ type ConfigType = {
|
|||
redisStreamReadBlockMs: number;
|
||||
jwtSecret: string;
|
||||
jwtCookieName: string;
|
||||
jwtChannelIdKey: string;
|
||||
socketResponseTimeoutMs: number;
|
||||
pingSocketsIntervalMs: number;
|
||||
gcChannelsIntervalMs: number;
|
||||
|
|
@ -54,6 +55,7 @@ function defaultConfig(): ConfigType {
|
|||
redisStreamReadBlockMs: 5000,
|
||||
jwtSecret: '',
|
||||
jwtCookieName: 'async-token',
|
||||
jwtChannelIdKey: 'channel',
|
||||
socketResponseTimeoutMs: 60 * 1000,
|
||||
pingSocketsIntervalMs: 20 * 1000,
|
||||
gcChannelsIntervalMs: 120 * 1000,
|
||||
|
|
|
|||
|
|
@ -53,7 +53,7 @@ interface EventValue {
|
|||
result_url?: string;
|
||||
}
|
||||
interface JwtPayload {
|
||||
channel: string;
|
||||
[key: string]: string;
|
||||
}
|
||||
interface FetchRangeFromStreamParams {
|
||||
sessionId: string;
|
||||
|
|
@ -253,14 +253,20 @@ export const processStreamResults = (results: StreamResult[]): void => {
|
|||
|
||||
/**
|
||||
* Verify and parse a JWT cookie from an HTTP request.
|
||||
* Returns the JWT payload or throws an error on invalid token.
|
||||
* Returns the channelId from the JWT payload found in the cookie
|
||||
* configured via 'jwtCookieName' in the config.
|
||||
*/
|
||||
const getJwtPayload = (request: http.IncomingMessage): JwtPayload => {
|
||||
const readChannelId = (request: http.IncomingMessage): string => {
|
||||
const cookies = cookie.parse(request.headers.cookie || '');
|
||||
const token = cookies[opts.jwtCookieName];
|
||||
|
||||
if (!token) throw new Error('JWT not present');
|
||||
return jwt.verify(token, opts.jwtSecret) as JwtPayload;
|
||||
const jwtPayload = jwt.verify(token, opts.jwtSecret) as JwtPayload;
|
||||
const channelId = jwtPayload[opts.jwtChannelIdKey];
|
||||
|
||||
if (!channelId) throw new Error('Channel ID not present in JWT');
|
||||
|
||||
return channelId;
|
||||
};
|
||||
|
||||
/**
|
||||
|
|
@ -286,8 +292,7 @@ export const incrementId = (id: string): string => {
|
|||
* WebSocket `connection` event handler, called via wss
|
||||
*/
|
||||
export const wsConnection = (ws: WebSocket, request: http.IncomingMessage) => {
|
||||
const jwtPayload: JwtPayload = getJwtPayload(request);
|
||||
const channel: string = jwtPayload.channel;
|
||||
const channel: string = readChannelId(request);
|
||||
const socketInstance: SocketInstance = { ws, channel, pongTs: Date.now() };
|
||||
|
||||
// add this ws instance to the internal registry
|
||||
|
|
@ -351,8 +356,7 @@ export const httpUpgrade = (
|
|||
head: Buffer,
|
||||
) => {
|
||||
try {
|
||||
const jwtPayload: JwtPayload = getJwtPayload(request);
|
||||
if (!jwtPayload.channel) throw new Error('Channel ID not present');
|
||||
readChannelId(request);
|
||||
} catch (err) {
|
||||
// JWT invalid, do not establish a WebSocket connection
|
||||
logger.error(err);
|
||||
|
|
|
|||
|
|
@ -88,9 +88,9 @@ class AsyncEventsRestApi(BaseSupersetApi):
|
|||
$ref: '#/components/responses/500'
|
||||
"""
|
||||
try:
|
||||
async_channel_id = async_query_manager.parse_jwt_from_request(request)[
|
||||
"channel"
|
||||
]
|
||||
async_channel_id = async_query_manager.parse_channel_id_from_request(
|
||||
request
|
||||
)
|
||||
last_event_id = request.args.get("last_id")
|
||||
events = async_query_manager.read_events(async_channel_id, last_event_id)
|
||||
|
||||
|
|
|
|||
|
|
@ -82,6 +82,9 @@ class AsyncQueryManager:
|
|||
self._jwt_cookie_domain: Optional[str]
|
||||
self._jwt_cookie_samesite: Optional[Literal["None", "Lax", "Strict"]] = None
|
||||
self._jwt_secret: str
|
||||
self._load_chart_data_into_cache_job: Any = None
|
||||
# pylint: disable=invalid-name
|
||||
self._load_explore_json_into_cache_job: Any = None
|
||||
|
||||
def init_app(self, app: Flask) -> None:
|
||||
config = app.config
|
||||
|
|
@ -115,6 +118,19 @@ class AsyncQueryManager:
|
|||
self._jwt_cookie_domain = config["GLOBAL_ASYNC_QUERIES_JWT_COOKIE_DOMAIN"]
|
||||
self._jwt_secret = config["GLOBAL_ASYNC_QUERIES_JWT_SECRET"]
|
||||
|
||||
if config["GLOBAL_ASYNC_QUERIES_REGISTER_REQUEST_HANDLERS"]:
|
||||
self.register_request_handlers(app)
|
||||
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from superset.tasks.async_queries import (
|
||||
load_chart_data_into_cache,
|
||||
load_explore_json_into_cache,
|
||||
)
|
||||
|
||||
self._load_chart_data_into_cache_job = load_chart_data_into_cache
|
||||
self._load_explore_json_into_cache_job = load_explore_json_into_cache
|
||||
|
||||
def register_request_handlers(self, app: Flask) -> None:
|
||||
@app.after_request
|
||||
def validate_session(response: Response) -> Response:
|
||||
user_id = get_user_id()
|
||||
|
|
@ -149,13 +165,13 @@ class AsyncQueryManager:
|
|||
|
||||
return response
|
||||
|
||||
def parse_jwt_from_request(self, req: Request) -> dict[str, Any]:
|
||||
def parse_channel_id_from_request(self, req: Request) -> str:
|
||||
token = req.cookies.get(self._jwt_cookie_name)
|
||||
if not token:
|
||||
raise AsyncQueryTokenException("Token not preset")
|
||||
|
||||
try:
|
||||
return jwt.decode(token, self._jwt_secret, algorithms=["HS256"])
|
||||
return jwt.decode(token, self._jwt_secret, algorithms=["HS256"])["channel"]
|
||||
except Exception as ex:
|
||||
logger.warning("Parse jwt failed", exc_info=True)
|
||||
raise AsyncQueryTokenException("Failed to parse token") from ex
|
||||
|
|
@ -166,6 +182,31 @@ class AsyncQueryManager:
|
|||
channel_id, job_id, user_id, status=self.STATUS_PENDING
|
||||
)
|
||||
|
||||
# pylint: disable=too-many-arguments
|
||||
def submit_explore_json_job(
|
||||
self,
|
||||
channel_id: str,
|
||||
form_data: dict[str, Any],
|
||||
response_type: str,
|
||||
force: Optional[bool] = False,
|
||||
user_id: Optional[int] = None,
|
||||
) -> dict[str, Any]:
|
||||
job_metadata = self.init_job(channel_id, user_id)
|
||||
self._load_explore_json_into_cache_job.delay(
|
||||
job_metadata,
|
||||
form_data,
|
||||
response_type,
|
||||
force,
|
||||
)
|
||||
return job_metadata
|
||||
|
||||
def submit_chart_data_job(
|
||||
self, channel_id: str, form_data: dict[str, Any], user_id: Optional[int]
|
||||
) -> dict[str, Any]:
|
||||
job_metadata = self.init_job(channel_id, user_id)
|
||||
self._load_chart_data_into_cache_job.delay(job_metadata, form_data)
|
||||
return job_metadata
|
||||
|
||||
def read_events(
|
||||
self, channel: str, last_id: Optional[str]
|
||||
) -> list[Optional[dict[str, Any]]]:
|
||||
|
|
|
|||
|
|
@ -20,7 +20,6 @@ from typing import Any, Optional
|
|||
from flask import Request
|
||||
|
||||
from superset.extensions import async_query_manager
|
||||
from superset.tasks.async_queries import load_chart_data_into_cache
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -29,10 +28,11 @@ class CreateAsyncChartDataJobCommand:
|
|||
_async_channel_id: str
|
||||
|
||||
def validate(self, request: Request) -> None:
|
||||
jwt_data = async_query_manager.parse_jwt_from_request(request)
|
||||
self._async_channel_id = jwt_data["channel"]
|
||||
self._async_channel_id = async_query_manager.parse_channel_id_from_request(
|
||||
request
|
||||
)
|
||||
|
||||
def run(self, form_data: dict[str, Any], user_id: Optional[int]) -> dict[str, Any]:
|
||||
job_metadata = async_query_manager.init_job(self._async_channel_id, user_id)
|
||||
load_chart_data_into_cache.delay(job_metadata, form_data)
|
||||
return job_metadata
|
||||
return async_query_manager.submit_chart_data_job(
|
||||
self._async_channel_id, form_data, user_id
|
||||
)
|
||||
|
|
|
|||
|
|
@ -26,8 +26,8 @@ feature_flags = config.DEFAULT_FEATURE_FLAGS.copy()
|
|||
feature_flags.update(config.FEATURE_FLAGS)
|
||||
feature_flags_func = config.GET_FEATURE_FLAGS_FUNC
|
||||
if feature_flags_func:
|
||||
# pylint: disable=not-callable
|
||||
try:
|
||||
# pylint: disable=not-callable
|
||||
feature_flags = feature_flags_func(feature_flags)
|
||||
except Exception: # pylint: disable=broad-except
|
||||
# bypass any feature flags that depend on context
|
||||
|
|
|
|||
|
|
@ -1524,6 +1524,7 @@ GLOBAL_ASYNC_QUERIES_REDIS_CONFIG = {
|
|||
GLOBAL_ASYNC_QUERIES_REDIS_STREAM_PREFIX = "async-events-"
|
||||
GLOBAL_ASYNC_QUERIES_REDIS_STREAM_LIMIT = 1000
|
||||
GLOBAL_ASYNC_QUERIES_REDIS_STREAM_LIMIT_FIREHOSE = 1000000
|
||||
GLOBAL_ASYNC_QUERIES_REGISTER_REQUEST_HANDLERS = True
|
||||
GLOBAL_ASYNC_QUERIES_JWT_COOKIE_NAME = "async-token"
|
||||
GLOBAL_ASYNC_QUERIES_JWT_COOKIE_SECURE = False
|
||||
GLOBAL_ASYNC_QUERIES_JWT_COOKIE_SAMESITE: None | (
|
||||
|
|
|
|||
|
|
@ -74,7 +74,6 @@ from superset.models.sql_lab import Query
|
|||
from superset.models.user_attributes import UserAttribute
|
||||
from superset.sqllab.utils import bootstrap_sqllab_data
|
||||
from superset.superset_typing import FlaskResponse
|
||||
from superset.tasks.async_queries import load_explore_json_into_cache
|
||||
from superset.utils import core as utils
|
||||
from superset.utils.cache import etag_cache
|
||||
from superset.utils.core import (
|
||||
|
|
@ -320,14 +319,11 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
|
|||
# at which point they will call the /explore_json/data/<cache_key>
|
||||
# endpoint to retrieve the results.
|
||||
try:
|
||||
async_channel_id = async_query_manager.parse_jwt_from_request(
|
||||
request
|
||||
)["channel"]
|
||||
job_metadata = async_query_manager.init_job(
|
||||
async_channel_id, get_user_id()
|
||||
async_channel_id = (
|
||||
async_query_manager.parse_channel_id_from_request(request)
|
||||
)
|
||||
load_explore_json_into_cache.delay(
|
||||
job_metadata, form_data, response_type, force
|
||||
job_metadata = async_query_manager.submit_explore_json_job(
|
||||
async_channel_id, form_data, response_type, force, get_user_id()
|
||||
)
|
||||
except AsyncQueryTokenException:
|
||||
return json_error_response("Not authorized", 401)
|
||||
|
|
|
|||
|
|
@ -20,18 +20,11 @@ from uuid import uuid4
|
|||
|
||||
import pytest
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from flask import g
|
||||
|
||||
from superset.charts.commands.exceptions import ChartDataQueryFailedError
|
||||
from superset.charts.data.commands.get_data_command import ChartDataCommand
|
||||
from superset.exceptions import SupersetException
|
||||
from superset.extensions import async_query_manager, security_manager
|
||||
from superset.tasks import async_queries
|
||||
from superset.tasks.async_queries import (
|
||||
load_chart_data_into_cache,
|
||||
load_explore_json_into_cache,
|
||||
)
|
||||
from superset.utils.core import get_user_id
|
||||
from tests.integration_tests.base_tests import SupersetTestCase
|
||||
from tests.integration_tests.fixtures.birth_names_dashboard import (
|
||||
load_birth_names_dashboard_with_slices,
|
||||
|
|
@ -43,10 +36,14 @@ from tests.integration_tests.test_app import app
|
|||
|
||||
|
||||
class TestAsyncQueries(SupersetTestCase):
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
@pytest.mark.usefixtures(
|
||||
"load_birth_names_data", "load_birth_names_dashboard_with_slices"
|
||||
)
|
||||
@mock.patch.object(async_query_manager, "update_job")
|
||||
@mock.patch.object(async_queries, "set_form_data")
|
||||
@mock.patch("superset.tasks.async_queries.set_form_data")
|
||||
def test_load_chart_data_into_cache(self, mock_set_form_data, mock_update_job):
|
||||
from superset.tasks.async_queries import load_chart_data_into_cache
|
||||
|
||||
app._got_first_request = False
|
||||
async_query_manager.init_app(app)
|
||||
query_context = get_query_context("birth_names")
|
||||
|
|
@ -70,6 +67,8 @@ class TestAsyncQueries(SupersetTestCase):
|
|||
)
|
||||
@mock.patch.object(async_query_manager, "update_job")
|
||||
def test_load_chart_data_into_cache_error(self, mock_update_job, mock_run_command):
|
||||
from superset.tasks.async_queries import load_chart_data_into_cache
|
||||
|
||||
app._got_first_request = False
|
||||
async_query_manager.init_app(app)
|
||||
query_context = get_query_context("birth_names")
|
||||
|
|
@ -93,6 +92,8 @@ class TestAsyncQueries(SupersetTestCase):
|
|||
def test_soft_timeout_load_chart_data_into_cache(
|
||||
self, mock_update_job, mock_run_command
|
||||
):
|
||||
from superset.tasks.async_queries import load_chart_data_into_cache
|
||||
|
||||
app._got_first_request = False
|
||||
async_query_manager.init_app(app)
|
||||
user = security_manager.find_user("gamma")
|
||||
|
|
@ -107,9 +108,8 @@ class TestAsyncQueries(SupersetTestCase):
|
|||
errors = ["A timeout occurred while loading chart data"]
|
||||
|
||||
with pytest.raises(SoftTimeLimitExceeded):
|
||||
with mock.patch.object(
|
||||
async_queries,
|
||||
"set_form_data",
|
||||
with mock.patch(
|
||||
"superset.tasks.async_queries.set_form_data"
|
||||
) as set_form_data:
|
||||
set_form_data.side_effect = SoftTimeLimitExceeded()
|
||||
load_chart_data_into_cache(job_metadata, form_data)
|
||||
|
|
@ -118,6 +118,8 @@ class TestAsyncQueries(SupersetTestCase):
|
|||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
@mock.patch.object(async_query_manager, "update_job")
|
||||
def test_load_explore_json_into_cache(self, mock_update_job):
|
||||
from superset.tasks.async_queries import load_explore_json_into_cache
|
||||
|
||||
app._got_first_request = False
|
||||
async_query_manager.init_app(app)
|
||||
table = self.get_table(name="birth_names")
|
||||
|
|
@ -146,10 +148,12 @@ class TestAsyncQueries(SupersetTestCase):
|
|||
)
|
||||
|
||||
@mock.patch.object(async_query_manager, "update_job")
|
||||
@mock.patch.object(async_queries, "set_form_data")
|
||||
@mock.patch("superset.tasks.async_queries.set_form_data")
|
||||
def test_load_explore_json_into_cache_error(
|
||||
self, mock_set_form_data, mock_update_job
|
||||
):
|
||||
from superset.tasks.async_queries import load_explore_json_into_cache
|
||||
|
||||
app._got_first_request = False
|
||||
async_query_manager.init_app(app)
|
||||
user = security_manager.find_user("gamma")
|
||||
|
|
@ -174,6 +178,8 @@ class TestAsyncQueries(SupersetTestCase):
|
|||
def test_soft_timeout_load_explore_json_into_cache(
|
||||
self, mock_update_job, mock_run_command
|
||||
):
|
||||
from superset.tasks.async_queries import load_explore_json_into_cache
|
||||
|
||||
app._got_first_request = False
|
||||
async_query_manager.init_app(app)
|
||||
user = security_manager.find_user("gamma")
|
||||
|
|
@ -188,9 +194,8 @@ class TestAsyncQueries(SupersetTestCase):
|
|||
errors = ["A timeout occurred while loading explore json, error"]
|
||||
|
||||
with pytest.raises(SoftTimeLimitExceeded):
|
||||
with mock.patch.object(
|
||||
async_queries,
|
||||
"set_form_data",
|
||||
with mock.patch(
|
||||
"superset.tasks.async_queries.set_form_data"
|
||||
) as set_form_data:
|
||||
set_form_data.side_effect = SoftTimeLimitExceeded()
|
||||
load_explore_json_into_cache(job_metadata, form_data)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,16 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
|
@ -0,0 +1,67 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
from jwt import encode
|
||||
from pytest import fixture, raises
|
||||
|
||||
from superset.async_events.async_query_manager import (
|
||||
AsyncQueryManager,
|
||||
AsyncQueryTokenException,
|
||||
)
|
||||
|
||||
JWT_TOKEN_SECRET = "some_secret"
|
||||
JWT_TOKEN_COOKIE_NAME = "superset_async_jwt"
|
||||
|
||||
|
||||
@fixture
|
||||
def async_query_manager():
|
||||
query_manager = AsyncQueryManager()
|
||||
query_manager._jwt_secret = JWT_TOKEN_SECRET
|
||||
query_manager._jwt_cookie_name = JWT_TOKEN_COOKIE_NAME
|
||||
|
||||
return query_manager
|
||||
|
||||
|
||||
def test_parse_channel_id_from_request(async_query_manager):
|
||||
encoded_token = encode(
|
||||
{"channel": "test_channel_id"}, JWT_TOKEN_SECRET, algorithm="HS256"
|
||||
)
|
||||
|
||||
request = Mock()
|
||||
request.cookies = {"superset_async_jwt": encoded_token}
|
||||
|
||||
assert (
|
||||
async_query_manager.parse_channel_id_from_request(request) == "test_channel_id"
|
||||
)
|
||||
|
||||
|
||||
def test_parse_channel_id_from_request_no_cookie(async_query_manager):
|
||||
request = Mock()
|
||||
request.cookies = {}
|
||||
|
||||
with raises(AsyncQueryTokenException):
|
||||
async_query_manager.parse_channel_id_from_request(request)
|
||||
|
||||
|
||||
def test_parse_channel_id_from_request_bad_jwt(async_query_manager):
|
||||
request = Mock()
|
||||
request.cookies = {"superset_async_jwt": "bad_jwt"}
|
||||
|
||||
with raises(AsyncQueryTokenException):
|
||||
async_query_manager.parse_channel_id_from_request(request)
|
||||
Loading…
Reference in New Issue