fix: trino thread app missing full context (#29981)

This commit is contained in:
Daniel Vaz Gaspar 2024-08-22 18:01:06 +01:00 committed by GitHub
parent c049771a7f
commit 4d821f44ae
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 41 additions and 2 deletions

View File

@ -27,7 +27,7 @@ from typing import Any, TYPE_CHECKING
import numpy as np
import pandas as pd
import pyarrow as pa
from flask import current_app, Flask, g
from flask import ctx, current_app, Flask, g
from sqlalchemy import text
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import URL
@ -227,12 +227,22 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
execute_event = threading.Event()
def _execute(
results: dict[str, Any], event: threading.Event, app: Flask
results: dict[str, Any],
event: threading.Event,
app: Flask,
g_copy: ctx._AppCtxGlobals,
) -> None:
logger.debug("Query %d: Running query: %s", query_id, sql)
try:
# Flask contexts are local to the thread that handles the request.
# When you spawn a new thread, it does not inherit the contexts
# from the parent thread,
# meaning the g object and other context-bound variables are not
# accessible
with app.app_context():
for key, value in g_copy.__dict__.items():
setattr(g, key, value)
cls.execute(cursor, sql, query.database)
except Exception as ex: # pylint: disable=broad-except
results["error"] = ex
@ -245,6 +255,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
execute_result,
execute_event,
current_app._get_current_object(), # pylint: disable=protected-access
g._get_current_object(), # pylint: disable=protected-access
),
)
execute_thread.start()

View File

@ -23,6 +23,7 @@ from unittest.mock import MagicMock, Mock, patch
import pandas as pd
import pytest
from flask import g, has_app_context
from pytest_mock import MockerFixture
from requests.exceptions import ConnectionError as RequestsConnectionError
from sqlalchemy import sql, text, types
@ -435,6 +436,33 @@ def test_execute_with_cursor_in_parallel(app, mocker: MockerFixture):
)
def test_execute_with_cursor_app_context(app, mocker: MockerFixture):
"""Test that `execute_with_cursor` still contains the current app context"""
from superset.db_engine_specs.trino import TrinoEngineSpec
mock_cursor = mocker.MagicMock()
mock_cursor.query_id = None
mock_query = mocker.MagicMock()
g.some_value = "some_value"
def _mock_execute(*args, **kwargs):
assert has_app_context()
assert g.some_value == "some_value"
with patch.object(TrinoEngineSpec, "execute", side_effect=_mock_execute):
with patch.dict(
"superset.config.DISALLOWED_SQL_FUNCTIONS",
{},
clear=True,
):
TrinoEngineSpec.execute_with_cursor(
cursor=mock_cursor,
sql="SELECT 1 FROM foo",
query=mock_query,
)
def test_get_columns(mocker: MockerFixture):
"""Test that ROW columns are not expanded without expand_rows"""
from superset.db_engine_specs.trino import TrinoEngineSpec