fix: Chart cache-warmup task fails on Superset 4.0 (#28706)
This commit is contained in:
parent
d7547fc4ef
commit
0744abe87b
|
|
@ -29,6 +29,7 @@ from superset.models.core import Log
|
|||
from superset.models.dashboard import Dashboard
|
||||
from superset.models.slice import Slice
|
||||
from superset.tags.models import Tag, TaggedObject
|
||||
from superset.tasks.utils import fetch_csrf_token
|
||||
from superset.utils import json
|
||||
from superset.utils.date_parser import parse_human_datetime
|
||||
from superset.utils.machine_auth import MachineAuthProvider
|
||||
|
|
@ -219,7 +220,10 @@ def fetch_url(data: str, headers: dict[str, str]) -> dict[str, str]:
|
|||
"""
|
||||
result = {}
|
||||
try:
|
||||
url = get_url_path("Superset.warm_up_cache")
|
||||
# Fetch CSRF token for API request
|
||||
headers.update(fetch_csrf_token(headers))
|
||||
|
||||
url = get_url_path("ChartRestApi.warm_up_cache")
|
||||
logger.info("Fetching %s with payload %s", url, data)
|
||||
req = request.Request(
|
||||
url, data=bytes(data, "utf-8"), headers=headers, method="PUT"
|
||||
|
|
|
|||
|
|
@ -17,12 +17,18 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
import logging
|
||||
from http.client import HTTPResponse
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
from urllib import request
|
||||
|
||||
from celery.utils.log import get_task_logger
|
||||
from flask import current_app, g
|
||||
|
||||
from superset.tasks.exceptions import ExecutorNotFoundError
|
||||
from superset.tasks.types import ExecutorType
|
||||
from superset.utils import json
|
||||
from superset.utils.urls import get_url_path
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.models.dashboard import Dashboard
|
||||
|
|
@ -30,6 +36,10 @@ if TYPE_CHECKING:
|
|||
from superset.reports.models import ReportSchedule
|
||||
|
||||
|
||||
logger = get_task_logger(__name__)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
# pylint: disable=too-many-branches
|
||||
def get_executor(
|
||||
executor_types: list[ExecutorType],
|
||||
|
|
@ -92,3 +102,39 @@ def get_current_user() -> str | None:
|
|||
return user.username
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def fetch_csrf_token(
|
||||
headers: dict[str, str], session_cookie_name: str = "session"
|
||||
) -> dict[str, str]:
|
||||
"""
|
||||
Fetches a CSRF token for API requests
|
||||
|
||||
:param headers: A map of headers to use in the request, including the session cookie
|
||||
:returns: A map of headers, including the session cookie and csrf token
|
||||
"""
|
||||
url = get_url_path("SecurityRestApi.csrf_token")
|
||||
logger.info("Fetching %s", url)
|
||||
req = request.Request(url, headers=headers, method="GET")
|
||||
response: HTTPResponse
|
||||
with request.urlopen(req, timeout=600) as response:
|
||||
body = response.read().decode("utf-8")
|
||||
session_cookie: Optional[str] = None
|
||||
cookie_headers = response.headers.get_all("set-cookie")
|
||||
if cookie_headers:
|
||||
for cookie in cookie_headers:
|
||||
cookie = cookie.split(";", 1)[0]
|
||||
name, value = cookie.split("=", 1)
|
||||
if name == session_cookie_name:
|
||||
session_cookie = value
|
||||
break
|
||||
|
||||
if response.status == 200:
|
||||
data = json.loads(body)
|
||||
res = {"X-CSRF-Token": data["result"]}
|
||||
if session_cookie is not None:
|
||||
res["Cookie"] = session_cookie
|
||||
return res
|
||||
|
||||
logger.error("Error fetching CSRF token, status code: %s", response.status)
|
||||
return {}
|
||||
|
|
|
|||
|
|
@ -29,9 +29,10 @@ from tests.integration_tests.test_app import app
|
|||
],
|
||||
ids=["Without trailing slash", "With trailing slash"],
|
||||
)
|
||||
@mock.patch("superset.tasks.cache.fetch_csrf_token")
|
||||
@mock.patch("superset.tasks.cache.request.Request")
|
||||
@mock.patch("superset.tasks.cache.request.urlopen")
|
||||
def test_fetch_url(mock_urlopen, mock_request_cls, base_url):
|
||||
def test_fetch_url(mock_urlopen, mock_request_cls, mock_fetch_csrf_token, base_url):
|
||||
from superset.tasks.cache import fetch_url
|
||||
|
||||
mock_request = mock.MagicMock()
|
||||
|
|
@ -40,18 +41,22 @@ def test_fetch_url(mock_urlopen, mock_request_cls, base_url):
|
|||
mock_urlopen.return_value = mock.MagicMock()
|
||||
mock_urlopen.return_value.code = 200
|
||||
|
||||
initial_headers = {"Cookie": "cookie", "key": "value"}
|
||||
csrf_headers = initial_headers | {"X-CSRF-Token": "csrf_token"}
|
||||
mock_fetch_csrf_token.return_value = csrf_headers
|
||||
|
||||
app.config["WEBDRIVER_BASEURL"] = base_url
|
||||
headers = {"key": "value"}
|
||||
data = "data"
|
||||
data_encoded = b"data"
|
||||
|
||||
result = fetch_url(data, headers)
|
||||
result = fetch_url(data, initial_headers)
|
||||
|
||||
assert data == result["success"]
|
||||
mock_fetch_csrf_token.assert_called_once_with(initial_headers)
|
||||
mock_request_cls.assert_called_once_with(
|
||||
"http://base-url/superset/warm_up_cache/",
|
||||
"http://base-url/api/v1/chart/warm_up_cache",
|
||||
data=data_encoded,
|
||||
headers=headers,
|
||||
headers=csrf_headers,
|
||||
method="PUT",
|
||||
)
|
||||
# assert the same Request object is used
|
||||
|
|
|
|||
|
|
@ -0,0 +1,64 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.integration_tests.test_app import app
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"base_url",
|
||||
[
|
||||
"http://base-url",
|
||||
"http://base-url/",
|
||||
],
|
||||
ids=["Without trailing slash", "With trailing slash"],
|
||||
)
|
||||
@mock.patch("superset.tasks.cache.request.Request")
|
||||
@mock.patch("superset.tasks.cache.request.urlopen")
|
||||
def test_fetch_csrf_token(mock_urlopen, mock_request_cls, base_url, app_context):
|
||||
from superset.tasks.utils import fetch_csrf_token
|
||||
|
||||
mock_request = mock.MagicMock()
|
||||
mock_request_cls.return_value = mock_request
|
||||
|
||||
mock_response = mock.MagicMock()
|
||||
mock_urlopen.return_value.__enter__.return_value = mock_response
|
||||
|
||||
mock_response.status = 200
|
||||
mock_response.read.return_value = b'{"result": "csrf_token"}'
|
||||
mock_response.headers.get_all.return_value = [
|
||||
"session=new_session_cookie",
|
||||
"async-token=websocket_cookie",
|
||||
]
|
||||
|
||||
app.config["WEBDRIVER_BASEURL"] = base_url
|
||||
headers = {"Cookie": "original_session_cookie"}
|
||||
|
||||
result_headers = fetch_csrf_token(headers)
|
||||
|
||||
mock_request_cls.assert_called_with(
|
||||
"http://base-url/api/v1/security/csrf_token/",
|
||||
headers=headers,
|
||||
method="GET",
|
||||
)
|
||||
|
||||
assert result_headers["X-CSRF-Token"] == "csrf_token"
|
||||
assert result_headers["Cookie"] == "new_session_cookie"
|
||||
# assert the same Request object is used
|
||||
mock_urlopen.assert_called_once_with(mock_request, timeout=mock.ANY)
|
||||
Loading…
Reference in New Issue