fix: trino cursor (#25897)

This commit is contained in:
Beto Dealmeida 2023-11-08 07:38:38 -05:00 committed by GitHub
parent 06ffcd29e2
commit cdb18e04ff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 19 additions and 13 deletions

View File

@ -184,7 +184,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
@classmethod
def execute_with_cursor(
cls, cursor: Any, sql: str, query: Query, session: Session
cls, cursor: Cursor, sql: str, query: Query, session: Session
) -> None:
"""
Trigger execution of a query and handle the resulting cursor.
@ -193,34 +193,40 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
in another thread and invoke `handle_cursor` to poll for the query ID
to appear on the cursor in parallel.
"""
# Fetch the query ID beforehand, since it might fail inside the thread due to
# how the SQLAlchemy session is handled.
query_id = query.id
execute_result: dict[str, Any] = {}
execute_event = threading.Event()
def _execute(results: dict[str, Any]) -> None:
logger.debug("Query %d: Running query: %s", query.id, sql)
def _execute(results: dict[str, Any], event: threading.Event) -> None:
logger.debug("Query %d: Running query: %s", query_id, sql)
# Pass result / exception information back to the parent thread
try:
cls.execute(cursor, sql)
results["complete"] = True
except Exception as ex: # pylint: disable=broad-except
results["complete"] = True
results["error"] = ex
finally:
event.set()
execute_thread = threading.Thread(target=_execute, args=(execute_result,))
execute_thread = threading.Thread(
target=_execute,
args=(execute_result, execute_event),
)
execute_thread.start()
# Wait for a query ID to be available before handling the cursor, as
# it's required by that method; it may never become available on error.
while not cursor.query_id and not execute_result.get("complete"):
while not cursor.query_id and not execute_event.is_set():
time.sleep(0.1)
logger.debug("Query %d: Handling cursor", query.id)
logger.debug("Query %d: Handling cursor", query_id)
cls.handle_cursor(cursor, query, session)
# Block until the query completes; same behaviour as the client itself
logger.debug("Query %d: Waiting for query to complete", query.id)
while not execute_result.get("complete"):
time.sleep(0.5)
logger.debug("Query %d: Waiting for query to complete", query_id)
execute_event.wait()
# Unfortunately we'll mangle the stack trace due to the thread, but
# throwing the original exception allows mapping database errors as normal
@ -234,7 +240,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
session.commit()
@classmethod
def cancel_query(cls, cursor: Any, query: Query, cancel_query_id: str) -> bool:
def cancel_query(cls, cursor: Cursor, query: Query, cancel_query_id: str) -> bool:
"""
Cancel query in the underlying database.