fix: trino thread app missing full context (#29981)
This commit is contained in:
parent
c049771a7f
commit
4d821f44ae
|
|
@ -27,7 +27,7 @@ from typing import Any, TYPE_CHECKING
|
|||
import numpy as np
|
||||
import pandas as pd
|
||||
import pyarrow as pa
|
||||
from flask import current_app, Flask, g
|
||||
from flask import ctx, current_app, Flask, g
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
from sqlalchemy.engine.url import URL
|
||||
|
|
@ -227,12 +227,22 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
|
|||
execute_event = threading.Event()
|
||||
|
||||
def _execute(
|
||||
results: dict[str, Any], event: threading.Event, app: Flask
|
||||
results: dict[str, Any],
|
||||
event: threading.Event,
|
||||
app: Flask,
|
||||
g_copy: ctx._AppCtxGlobals,
|
||||
) -> None:
|
||||
logger.debug("Query %d: Running query: %s", query_id, sql)
|
||||
|
||||
try:
|
||||
# Flask contexts are local to the thread that handles the request.
|
||||
# When you spawn a new thread, it does not inherit the contexts
|
||||
# from the parent thread,
|
||||
# meaning the g object and other context-bound variables are not
|
||||
# accessible
|
||||
with app.app_context():
|
||||
for key, value in g_copy.__dict__.items():
|
||||
setattr(g, key, value)
|
||||
cls.execute(cursor, sql, query.database)
|
||||
except Exception as ex: # pylint: disable=broad-except
|
||||
results["error"] = ex
|
||||
|
|
@ -245,6 +255,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
|
|||
execute_result,
|
||||
execute_event,
|
||||
current_app._get_current_object(), # pylint: disable=protected-access
|
||||
g._get_current_object(), # pylint: disable=protected-access
|
||||
),
|
||||
)
|
||||
execute_thread.start()
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ from unittest.mock import MagicMock, Mock, patch
|
|||
|
||||
import pandas as pd
|
||||
import pytest
|
||||
from flask import g, has_app_context
|
||||
from pytest_mock import MockerFixture
|
||||
from requests.exceptions import ConnectionError as RequestsConnectionError
|
||||
from sqlalchemy import sql, text, types
|
||||
|
|
@ -435,6 +436,33 @@ def test_execute_with_cursor_in_parallel(app, mocker: MockerFixture):
|
|||
)
|
||||
|
||||
|
||||
def test_execute_with_cursor_app_context(app, mocker: MockerFixture):
|
||||
"""Test that `execute_with_cursor` still contains the current app context"""
|
||||
from superset.db_engine_specs.trino import TrinoEngineSpec
|
||||
|
||||
mock_cursor = mocker.MagicMock()
|
||||
mock_cursor.query_id = None
|
||||
|
||||
mock_query = mocker.MagicMock()
|
||||
g.some_value = "some_value"
|
||||
|
||||
def _mock_execute(*args, **kwargs):
|
||||
assert has_app_context()
|
||||
assert g.some_value == "some_value"
|
||||
|
||||
with patch.object(TrinoEngineSpec, "execute", side_effect=_mock_execute):
|
||||
with patch.dict(
|
||||
"superset.config.DISALLOWED_SQL_FUNCTIONS",
|
||||
{},
|
||||
clear=True,
|
||||
):
|
||||
TrinoEngineSpec.execute_with_cursor(
|
||||
cursor=mock_cursor,
|
||||
sql="SELECT 1 FROM foo",
|
||||
query=mock_query,
|
||||
)
|
||||
|
||||
|
||||
def test_get_columns(mocker: MockerFixture):
|
||||
"""Test that ROW columns are not expanded without expand_rows"""
|
||||
from superset.db_engine_specs.trino import TrinoEngineSpec
|
||||
|
|
|
|||
Loading…
Reference in New Issue