chore(async): Initial Refactoring of Global Async Queries (#25466)

This commit is contained in:
Craig Rueda 2023-10-02 17:22:07 -07:00 committed by GitHub
parent 36ed617090
commit db7f5fed31
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 176 additions and 44 deletions

View File

@ -38,6 +38,7 @@ type ConfigType = {
redisStreamReadBlockMs: number;
jwtSecret: string;
jwtCookieName: string;
jwtChannelIdKey: string;
socketResponseTimeoutMs: number;
pingSocketsIntervalMs: number;
gcChannelsIntervalMs: number;
@ -54,6 +55,7 @@ function defaultConfig(): ConfigType {
redisStreamReadBlockMs: 5000,
jwtSecret: '',
jwtCookieName: 'async-token',
jwtChannelIdKey: 'channel',
socketResponseTimeoutMs: 60 * 1000,
pingSocketsIntervalMs: 20 * 1000,
gcChannelsIntervalMs: 120 * 1000,

View File

@ -53,7 +53,7 @@ interface EventValue {
result_url?: string;
}
interface JwtPayload {
channel: string;
[key: string]: string;
}
interface FetchRangeFromStreamParams {
sessionId: string;
@ -253,14 +253,20 @@ export const processStreamResults = (results: StreamResult[]): void => {
/**
* Verify and parse a JWT cookie from an HTTP request.
* Returns the JWT payload or throws an error on invalid token.
* Returns the channelId from the JWT payload found in the cookie
* configured via 'jwtCookieName' in the config.
*/
const getJwtPayload = (request: http.IncomingMessage): JwtPayload => {
const readChannelId = (request: http.IncomingMessage): string => {
const cookies = cookie.parse(request.headers.cookie || '');
const token = cookies[opts.jwtCookieName];
if (!token) throw new Error('JWT not present');
return jwt.verify(token, opts.jwtSecret) as JwtPayload;
const jwtPayload = jwt.verify(token, opts.jwtSecret) as JwtPayload;
const channelId = jwtPayload[opts.jwtChannelIdKey];
if (!channelId) throw new Error('Channel ID not present in JWT');
return channelId;
};
/**
@ -286,8 +292,7 @@ export const incrementId = (id: string): string => {
* WebSocket `connection` event handler, called via wss
*/
export const wsConnection = (ws: WebSocket, request: http.IncomingMessage) => {
const jwtPayload: JwtPayload = getJwtPayload(request);
const channel: string = jwtPayload.channel;
const channel: string = readChannelId(request);
const socketInstance: SocketInstance = { ws, channel, pongTs: Date.now() };
// add this ws instance to the internal registry
@ -351,8 +356,7 @@ export const httpUpgrade = (
head: Buffer,
) => {
try {
const jwtPayload: JwtPayload = getJwtPayload(request);
if (!jwtPayload.channel) throw new Error('Channel ID not present');
readChannelId(request);
} catch (err) {
// JWT invalid, do not establish a WebSocket connection
logger.error(err);

View File

@ -88,9 +88,9 @@ class AsyncEventsRestApi(BaseSupersetApi):
$ref: '#/components/responses/500'
"""
try:
async_channel_id = async_query_manager.parse_jwt_from_request(request)[
"channel"
]
async_channel_id = async_query_manager.parse_channel_id_from_request(
request
)
last_event_id = request.args.get("last_id")
events = async_query_manager.read_events(async_channel_id, last_event_id)

View File

@ -82,6 +82,9 @@ class AsyncQueryManager:
self._jwt_cookie_domain: Optional[str]
self._jwt_cookie_samesite: Optional[Literal["None", "Lax", "Strict"]] = None
self._jwt_secret: str
self._load_chart_data_into_cache_job: Any = None
# pylint: disable=invalid-name
self._load_explore_json_into_cache_job: Any = None
def init_app(self, app: Flask) -> None:
config = app.config
@ -115,6 +118,19 @@ class AsyncQueryManager:
self._jwt_cookie_domain = config["GLOBAL_ASYNC_QUERIES_JWT_COOKIE_DOMAIN"]
self._jwt_secret = config["GLOBAL_ASYNC_QUERIES_JWT_SECRET"]
if config["GLOBAL_ASYNC_QUERIES_REGISTER_REQUEST_HANDLERS"]:
self.register_request_handlers(app)
# pylint: disable=import-outside-toplevel
from superset.tasks.async_queries import (
load_chart_data_into_cache,
load_explore_json_into_cache,
)
self._load_chart_data_into_cache_job = load_chart_data_into_cache
self._load_explore_json_into_cache_job = load_explore_json_into_cache
def register_request_handlers(self, app: Flask) -> None:
@app.after_request
def validate_session(response: Response) -> Response:
user_id = get_user_id()
@ -149,13 +165,13 @@ class AsyncQueryManager:
return response
def parse_jwt_from_request(self, req: Request) -> dict[str, Any]:
def parse_channel_id_from_request(self, req: Request) -> str:
token = req.cookies.get(self._jwt_cookie_name)
if not token:
raise AsyncQueryTokenException("Token not preset")
try:
return jwt.decode(token, self._jwt_secret, algorithms=["HS256"])
return jwt.decode(token, self._jwt_secret, algorithms=["HS256"])["channel"]
except Exception as ex:
logger.warning("Parse jwt failed", exc_info=True)
raise AsyncQueryTokenException("Failed to parse token") from ex
@ -166,6 +182,31 @@ class AsyncQueryManager:
channel_id, job_id, user_id, status=self.STATUS_PENDING
)
# pylint: disable=too-many-arguments
def submit_explore_json_job(
self,
channel_id: str,
form_data: dict[str, Any],
response_type: str,
force: Optional[bool] = False,
user_id: Optional[int] = None,
) -> dict[str, Any]:
job_metadata = self.init_job(channel_id, user_id)
self._load_explore_json_into_cache_job.delay(
job_metadata,
form_data,
response_type,
force,
)
return job_metadata
def submit_chart_data_job(
self, channel_id: str, form_data: dict[str, Any], user_id: Optional[int]
) -> dict[str, Any]:
job_metadata = self.init_job(channel_id, user_id)
self._load_chart_data_into_cache_job.delay(job_metadata, form_data)
return job_metadata
def read_events(
self, channel: str, last_id: Optional[str]
) -> list[Optional[dict[str, Any]]]:

View File

@ -20,7 +20,6 @@ from typing import Any, Optional
from flask import Request
from superset.extensions import async_query_manager
from superset.tasks.async_queries import load_chart_data_into_cache
logger = logging.getLogger(__name__)
@ -29,10 +28,11 @@ class CreateAsyncChartDataJobCommand:
_async_channel_id: str
def validate(self, request: Request) -> None:
jwt_data = async_query_manager.parse_jwt_from_request(request)
self._async_channel_id = jwt_data["channel"]
self._async_channel_id = async_query_manager.parse_channel_id_from_request(
request
)
def run(self, form_data: dict[str, Any], user_id: Optional[int]) -> dict[str, Any]:
job_metadata = async_query_manager.init_job(self._async_channel_id, user_id)
load_chart_data_into_cache.delay(job_metadata, form_data)
return job_metadata
return async_query_manager.submit_chart_data_job(
self._async_channel_id, form_data, user_id
)

View File

@ -26,8 +26,8 @@ feature_flags = config.DEFAULT_FEATURE_FLAGS.copy()
feature_flags.update(config.FEATURE_FLAGS)
feature_flags_func = config.GET_FEATURE_FLAGS_FUNC
if feature_flags_func:
# pylint: disable=not-callable
try:
# pylint: disable=not-callable
feature_flags = feature_flags_func(feature_flags)
except Exception: # pylint: disable=broad-except
# bypass any feature flags that depend on context

View File

@ -1524,6 +1524,7 @@ GLOBAL_ASYNC_QUERIES_REDIS_CONFIG = {
GLOBAL_ASYNC_QUERIES_REDIS_STREAM_PREFIX = "async-events-"
GLOBAL_ASYNC_QUERIES_REDIS_STREAM_LIMIT = 1000
GLOBAL_ASYNC_QUERIES_REDIS_STREAM_LIMIT_FIREHOSE = 1000000
GLOBAL_ASYNC_QUERIES_REGISTER_REQUEST_HANDLERS = True
GLOBAL_ASYNC_QUERIES_JWT_COOKIE_NAME = "async-token"
GLOBAL_ASYNC_QUERIES_JWT_COOKIE_SECURE = False
GLOBAL_ASYNC_QUERIES_JWT_COOKIE_SAMESITE: None | (

View File

@ -74,7 +74,6 @@ from superset.models.sql_lab import Query
from superset.models.user_attributes import UserAttribute
from superset.sqllab.utils import bootstrap_sqllab_data
from superset.superset_typing import FlaskResponse
from superset.tasks.async_queries import load_explore_json_into_cache
from superset.utils import core as utils
from superset.utils.cache import etag_cache
from superset.utils.core import (
@ -320,14 +319,11 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
# at which point they will call the /explore_json/data/<cache_key>
# endpoint to retrieve the results.
try:
async_channel_id = async_query_manager.parse_jwt_from_request(
request
)["channel"]
job_metadata = async_query_manager.init_job(
async_channel_id, get_user_id()
async_channel_id = (
async_query_manager.parse_channel_id_from_request(request)
)
load_explore_json_into_cache.delay(
job_metadata, form_data, response_type, force
job_metadata = async_query_manager.submit_explore_json_job(
async_channel_id, form_data, response_type, force, get_user_id()
)
except AsyncQueryTokenException:
return json_error_response("Not authorized", 401)

View File

@ -20,18 +20,11 @@ from uuid import uuid4
import pytest
from celery.exceptions import SoftTimeLimitExceeded
from flask import g
from superset.charts.commands.exceptions import ChartDataQueryFailedError
from superset.charts.data.commands.get_data_command import ChartDataCommand
from superset.exceptions import SupersetException
from superset.extensions import async_query_manager, security_manager
from superset.tasks import async_queries
from superset.tasks.async_queries import (
load_chart_data_into_cache,
load_explore_json_into_cache,
)
from superset.utils.core import get_user_id
from tests.integration_tests.base_tests import SupersetTestCase
from tests.integration_tests.fixtures.birth_names_dashboard import (
load_birth_names_dashboard_with_slices,
@ -43,10 +36,14 @@ from tests.integration_tests.test_app import app
class TestAsyncQueries(SupersetTestCase):
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@pytest.mark.usefixtures(
"load_birth_names_data", "load_birth_names_dashboard_with_slices"
)
@mock.patch.object(async_query_manager, "update_job")
@mock.patch.object(async_queries, "set_form_data")
@mock.patch("superset.tasks.async_queries.set_form_data")
def test_load_chart_data_into_cache(self, mock_set_form_data, mock_update_job):
from superset.tasks.async_queries import load_chart_data_into_cache
app._got_first_request = False
async_query_manager.init_app(app)
query_context = get_query_context("birth_names")
@ -70,6 +67,8 @@ class TestAsyncQueries(SupersetTestCase):
)
@mock.patch.object(async_query_manager, "update_job")
def test_load_chart_data_into_cache_error(self, mock_update_job, mock_run_command):
from superset.tasks.async_queries import load_chart_data_into_cache
app._got_first_request = False
async_query_manager.init_app(app)
query_context = get_query_context("birth_names")
@ -93,6 +92,8 @@ class TestAsyncQueries(SupersetTestCase):
def test_soft_timeout_load_chart_data_into_cache(
self, mock_update_job, mock_run_command
):
from superset.tasks.async_queries import load_chart_data_into_cache
app._got_first_request = False
async_query_manager.init_app(app)
user = security_manager.find_user("gamma")
@ -107,9 +108,8 @@ class TestAsyncQueries(SupersetTestCase):
errors = ["A timeout occurred while loading chart data"]
with pytest.raises(SoftTimeLimitExceeded):
with mock.patch.object(
async_queries,
"set_form_data",
with mock.patch(
"superset.tasks.async_queries.set_form_data"
) as set_form_data:
set_form_data.side_effect = SoftTimeLimitExceeded()
load_chart_data_into_cache(job_metadata, form_data)
@ -118,6 +118,8 @@ class TestAsyncQueries(SupersetTestCase):
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@mock.patch.object(async_query_manager, "update_job")
def test_load_explore_json_into_cache(self, mock_update_job):
from superset.tasks.async_queries import load_explore_json_into_cache
app._got_first_request = False
async_query_manager.init_app(app)
table = self.get_table(name="birth_names")
@ -146,10 +148,12 @@ class TestAsyncQueries(SupersetTestCase):
)
@mock.patch.object(async_query_manager, "update_job")
@mock.patch.object(async_queries, "set_form_data")
@mock.patch("superset.tasks.async_queries.set_form_data")
def test_load_explore_json_into_cache_error(
self, mock_set_form_data, mock_update_job
):
from superset.tasks.async_queries import load_explore_json_into_cache
app._got_first_request = False
async_query_manager.init_app(app)
user = security_manager.find_user("gamma")
@ -174,6 +178,8 @@ class TestAsyncQueries(SupersetTestCase):
def test_soft_timeout_load_explore_json_into_cache(
self, mock_update_job, mock_run_command
):
from superset.tasks.async_queries import load_explore_json_into_cache
app._got_first_request = False
async_query_manager.init_app(app)
user = security_manager.find_user("gamma")
@ -188,9 +194,8 @@ class TestAsyncQueries(SupersetTestCase):
errors = ["A timeout occurred while loading explore json, error"]
with pytest.raises(SoftTimeLimitExceeded):
with mock.patch.object(
async_queries,
"set_form_data",
with mock.patch(
"superset.tasks.async_queries.set_form_data"
) as set_form_data:
set_form_data.side_effect = SoftTimeLimitExceeded()
load_explore_json_into_cache(job_metadata, form_data)

View File

@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@ -0,0 +1,67 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from unittest.mock import Mock
from jwt import encode
from pytest import fixture, raises
from superset.async_events.async_query_manager import (
AsyncQueryManager,
AsyncQueryTokenException,
)
JWT_TOKEN_SECRET = "some_secret"
JWT_TOKEN_COOKIE_NAME = "superset_async_jwt"
@fixture
def async_query_manager():
query_manager = AsyncQueryManager()
query_manager._jwt_secret = JWT_TOKEN_SECRET
query_manager._jwt_cookie_name = JWT_TOKEN_COOKIE_NAME
return query_manager
def test_parse_channel_id_from_request(async_query_manager):
encoded_token = encode(
{"channel": "test_channel_id"}, JWT_TOKEN_SECRET, algorithm="HS256"
)
request = Mock()
request.cookies = {"superset_async_jwt": encoded_token}
assert (
async_query_manager.parse_channel_id_from_request(request) == "test_channel_id"
)
def test_parse_channel_id_from_request_no_cookie(async_query_manager):
request = Mock()
request.cookies = {}
with raises(AsyncQueryTokenException):
async_query_manager.parse_channel_id_from_request(request)
def test_parse_channel_id_from_request_bad_jwt(async_query_manager):
request = Mock()
request.cookies = {"superset_async_jwt": "bad_jwt"}
with raises(AsyncQueryTokenException):
async_query_manager.parse_channel_id_from_request(request)