feat(trino): support early cancellation of queries (#22498)

This commit is contained in:
Ville Brofeldt 2022-12-24 06:31:46 +02:00 committed by GitHub
parent 7926a43aed
commit b6d39d194c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 231 additions and 74 deletions

View File

@ -17,7 +17,7 @@
* under the License.
*/
import shortid from 'shortid';
import { t, SupersetClient } from '@superset-ui/core';
import { SupersetClient, t } from '@superset-ui/core';
import invert from 'lodash/invert';
import mapKeys from 'lodash/mapKeys';
import { isFeatureEnabled, FeatureFlag } from 'src/featureFlags';
@ -229,11 +229,13 @@ export function startQuery(query) {
export function querySuccess(query, results) {
return function (dispatch) {
const sqlEditorId = results?.query?.sqlEditorId;
const sync =
sqlEditorId &&
!query.isDataPreview &&
isFeatureEnabled(FeatureFlag.SQLLAB_BACKEND_PERSISTENCE)
? SupersetClient.put({
endpoint: encodeURI(`/tabstateview/${results.query.sqlEditorId}`),
endpoint: encodeURI(`/tabstateview/${sqlEditorId}`),
postPayload: { latest_query_id: query.id },
})
: Promise.resolve();

View File

@ -30,6 +30,7 @@ import {
initialState,
queryId,
} from 'src/SqlLab/fixtures';
import { QueryState } from '@superset-ui/core';
const middlewares = [thunk];
const mockStore = configureMockStore(middlewares);
@ -502,6 +503,7 @@ describe('async actions', () => {
const results = {
data: mockBigNumber,
query: { sqlEditorId: 'abcd' },
status: QueryState.SUCCESS,
query_id: 'efgh',
};
fetchMock.get(fetchQueryEndpoint, JSON.stringify(results), {
@ -525,6 +527,35 @@ describe('async actions', () => {
expect(fetchMock.calls(updateTabStateEndpoint)).toHaveLength(1);
});
});
it("doesn't update the tab state in the backend on stoppped query", () => {
expect.assertions(2);
const results = {
status: QueryState.STOPPED,
query_id: 'efgh',
};
fetchMock.get(fetchQueryEndpoint, JSON.stringify(results), {
overwriteRoutes: true,
});
const store = mockStore({});
const expectedActions = [
{
type: actions.REQUEST_QUERY_RESULTS,
query,
},
// missing below
{
type: actions.QUERY_SUCCESS,
query,
results,
},
];
return store.dispatch(actions.fetchQueryResults(query)).then(() => {
expect(store.getActions()).toEqual(expectedActions);
expect(fetchMock.calls(updateTabStateEndpoint)).toHaveLength(0);
});
});
});
describe('addQueryEditor', () => {

View File

@ -16,13 +16,13 @@
* specific language governing permissions and limitations
* under the License.
*/
import React, { useState, useEffect, useCallback } from 'react';
import React, { useCallback, useEffect, useState } from 'react';
import { useDispatch } from 'react-redux';
import ButtonGroup from 'src/components/ButtonGroup';
import Alert from 'src/components/Alert';
import Button from 'src/components/Button';
import shortid from 'shortid';
import { styled, t, QueryResponse } from '@superset-ui/core';
import { QueryResponse, QueryState, styled, t } from '@superset-ui/core';
import { usePrevious } from 'src/hooks/usePrevious';
import ErrorMessageWithStackTrace from 'src/components/ErrorMessage/ErrorMessageWithStackTrace';
import {
@ -43,9 +43,9 @@ import CopyToClipboard from 'src/components/CopyToClipboard';
import { addDangerToast } from 'src/components/MessageToasts/actions';
import { prepareCopyToClipboardTabularData } from 'src/utils/common';
import {
CtasEnum,
clearQueryResults,
addQueryEditor,
clearQueryResults,
CtasEnum,
fetchQueryResults,
reFetchQueryResults,
reRunQuery,
@ -387,8 +387,8 @@ const ResultSet = ({
let trackingUrl;
if (
query.trackingUrl &&
query.state !== 'success' &&
query.state !== 'fetching'
query.state !== QueryState.SUCCESS &&
query.state !== QueryState.FETCHING
) {
trackingUrl = (
<Button
@ -397,7 +397,9 @@ const ResultSet = ({
href={query.trackingUrl}
target="_blank"
>
{query.state === 'running' ? t('Track job') : t('See query details')}
{query.state === QueryState.RUNNING
? t('Track job')
: t('See query details')}
</Button>
);
}
@ -406,11 +408,11 @@ const ResultSet = ({
sql = <HighlightedSql sql={query.sql} />;
}
if (query.state === 'stopped') {
if (query.state === QueryState.STOPPED) {
return <Alert type="warning" message={t('Query was stopped')} />;
}
if (query.state === 'failed') {
if (query.state === QueryState.FAILED) {
return (
<ResultlessStyles>
<ErrorMessageWithStackTrace
@ -426,7 +428,7 @@ const ResultSet = ({
);
}
if (query.state === 'success' && query.ctas) {
if (query.state === QueryState.SUCCESS && query.ctas) {
const { tempSchema, tempTable } = query;
let object = 'Table';
if (query.ctas_method === CtasEnum.VIEW) {
@ -465,7 +467,7 @@ const ResultSet = ({
);
}
if (query.state === 'success' && query.results) {
if (query.state === QueryState.SUCCESS && query.results) {
const { results } = query;
// Accounts for offset needed for height of ResultSetRowsReturned component if !limitReached
const rowMessageHeight = !limitReached ? 32 : 0;
@ -508,7 +510,7 @@ const ResultSet = ({
}
}
if (query.cached || (query.state === 'success' && !query.results)) {
if (query.cached || (query.state === QueryState.SUCCESS && !query.results)) {
if (query.isDataPreview) {
return (
<Button

View File

@ -53,7 +53,7 @@ const SqlEditorTabHeader: React.FC<Props> = ({ queryEditor }) => {
}),
shallowEqual,
);
const queryStatus = useSelector<SqlLabRootState, QueryState>(
const queryState = useSelector<SqlLabRootState, QueryState>(
({ sqlLab }) => sqlLab.queries[qe.latestQueryId || '']?.state || '',
);
const dispatch = useDispatch();
@ -139,7 +139,7 @@ const SqlEditorTabHeader: React.FC<Props> = ({ queryEditor }) => {
</Menu>
}
/>
<TabTitle>{qe.name}</TabTitle> <TabStatusIcon tabState={queryStatus} />{' '}
<TabTitle>{qe.name}</TabTitle> <TabStatusIcon tabState={queryState} />{' '}
</TabTitleWrapper>
);
};

View File

@ -16,8 +16,7 @@
* specific language governing permissions and limitations
* under the License.
*/
import { t } from '@superset-ui/core';
import { QueryState, t } from '@superset-ui/core';
import getInitialState from './getInitialState';
import * as actions from '../actions/sqlLab';
import { now } from '../../utils/dates';
@ -391,7 +390,7 @@ export default function sqlLabReducer(state = {}, action) {
},
[actions.STOP_QUERY]() {
return alterInObject(state, 'queries', action.query, {
state: 'stopped',
state: QueryState.STOPPED,
results: [],
});
},
@ -405,12 +404,16 @@ export default function sqlLabReducer(state = {}, action) {
},
[actions.REQUEST_QUERY_RESULTS]() {
return alterInObject(state, 'queries', action.query, {
state: 'fetching',
state: QueryState.FETCHING,
});
},
[actions.QUERY_SUCCESS]() {
// prevent race condition were query succeeds shortly after being canceled
if (action.query.state === 'stopped') {
// prevent race condition where query succeeds shortly after being canceled
// or the final result was unsuccessful
if (
action.query.state === QueryState.STOPPED ||
action.results.status !== QueryState.SUCCESS
) {
return state;
}
const alts = {
@ -418,7 +421,7 @@ export default function sqlLabReducer(state = {}, action) {
progress: 100,
results: action.results,
rows: action?.results?.query?.rows || 0,
state: 'success',
state: QueryState.SUCCESS,
limitingFactor: action?.results?.query?.limitingFactor,
tempSchema: action?.results?.query?.tempSchema,
tempTable: action?.results?.query?.tempTable,
@ -434,11 +437,11 @@ export default function sqlLabReducer(state = {}, action) {
return alterInObject(state, 'queries', action.query, alts);
},
[actions.QUERY_FAILED]() {
if (action.query.state === 'stopped') {
if (action.query.state === QueryState.STOPPED) {
return state;
}
const alts = {
state: 'failed',
state: QueryState.FAILED,
errors: action.errors,
errorMessage: action.msg,
endDttm: now(),
@ -723,8 +726,8 @@ export default function sqlLabReducer(state = {}, action) {
Object.entries(action.alteredQueries).forEach(([id, changedQuery]) => {
if (
!state.queries.hasOwnProperty(id) ||
(state.queries[id].state !== 'stopped' &&
state.queries[id].state !== 'failed')
(state.queries[id].state !== QueryState.STOPPED &&
state.queries[id].state !== QueryState.FAILED)
) {
if (changedQuery.changedOn > queriesLastUpdate) {
queriesLastUpdate = changedQuery.changedOn;
@ -738,8 +741,8 @@ export default function sqlLabReducer(state = {}, action) {
// because of async behavior, sql lab may still poll a couple of seconds
// when it started fetching or finished rendering results
state:
currentState === 'success' &&
['fetching', 'success'].includes(prevState)
currentState === QueryState.SUCCESS &&
[QueryState.FETCHING, QueryState.SUCCESS].includes(prevState)
? prevState
: currentState,
};

View File

@ -33,6 +33,7 @@ import ListView from 'src/components/ListView';
import Filters from 'src/components/ListView/Filters';
import SyntaxHighlighter from 'react-syntax-highlighter/dist/cjs/light';
import SubMenu from 'src/views/components/SubMenu';
import { QueryState } from '@superset-ui/core';
// store needed for withToasts
const mockStore = configureStore([thunk]);
@ -54,7 +55,7 @@ const mockQueries: QueryObject[] = [...new Array(3)].map((_, i) => ({
{ schema: 'foo', table: 'table' },
{ schema: 'bar', table: 'table_2' },
],
status: 'success',
status: QueryState.SUCCESS,
tab_name: 'Main Tab',
user: {
first_name: 'cool',

View File

@ -17,7 +17,13 @@
* under the License.
*/
import React, { useMemo, useState, useCallback, ReactElement } from 'react';
import { SupersetClient, t, styled, useTheme } from '@superset-ui/core';
import {
QueryState,
styled,
SupersetClient,
t,
useTheme,
} from '@superset-ui/core';
import moment from 'moment';
import {
createFetchRelated,
@ -127,7 +133,13 @@ function QueryList({ addDangerToast }: QueryListProps) {
row: {
original: { status },
},
}: any) => {
}: {
row: {
original: {
status: QueryState;
};
};
}) => {
const statusConfig: {
name: ReactElement | null;
label: string;
@ -135,33 +147,39 @@ function QueryList({ addDangerToast }: QueryListProps) {
name: null,
label: '',
};
if (status === 'success') {
if (status === QueryState.SUCCESS) {
statusConfig.name = (
<Icons.Check iconColor={theme.colors.success.base} />
);
statusConfig.label = t('Success');
} else if (status === 'failed' || status === 'stopped') {
} else if (
status === QueryState.FAILED ||
status === QueryState.STOPPED
) {
statusConfig.name = (
<Icons.XSmall
iconColor={
status === 'failed'
status === QueryState.FAILED
? theme.colors.error.base
: theme.colors.grayscale.base
}
/>
);
statusConfig.label = t('Failed');
} else if (status === 'running') {
} else if (status === QueryState.RUNNING) {
statusConfig.name = (
<Icons.Running iconColor={theme.colors.primary.base} />
);
statusConfig.label = t('Running');
} else if (status === 'timed_out') {
} else if (status === QueryState.TIMED_OUT) {
statusConfig.name = (
<Icons.Offline iconColor={theme.colors.grayscale.light1} />
);
statusConfig.label = t('Offline');
} else if (status === 'scheduled' || status === 'pending') {
} else if (
status === QueryState.SCHEDULED ||
status === QueryState.PENDING
) {
statusConfig.name = (
<Icons.Queued iconColor={theme.colors.grayscale.base} />
);

View File

@ -27,6 +27,7 @@ import QueryPreviewModal from 'src/views/CRUD/data/query/QueryPreviewModal';
import { QueryObject } from 'src/views/CRUD/types';
import SyntaxHighlighter from 'react-syntax-highlighter/dist/cjs/light';
import { act } from 'react-dom/test-utils';
import { QueryState } from '@superset-ui/core';
// store needed for withToasts
const mockStore = configureStore([thunk]);
@ -46,7 +47,7 @@ const mockQueries: QueryObject[] = [...new Array(3)].map((_, i) => ({
{ schema: 'foo', table: 'table' },
{ schema: 'bar', table: 'table_2' },
],
status: 'success',
status: QueryState.SUCCESS,
tab_name: 'Main Tab',
user: {
first_name: 'cool',

View File

@ -16,6 +16,7 @@
* specific language governing permissions and limitations
* under the License.
*/
import { QueryState } from '@superset-ui/core';
import { User } from 'src/types/bootstrapTypes';
import Database from 'src/types/Database';
import Owner from 'src/types/Owner';
@ -94,14 +95,7 @@ export interface QueryObject {
sql: string;
executed_sql: string | null;
sql_tables?: { catalog?: string; schema: string; table: string }[];
status:
| 'success'
| 'failed'
| 'stopped'
| 'running'
| 'timed_out'
| 'scheduled'
| 'pending';
status: QueryState;
tab_name: string;
user: {
first_name: string;

View File

@ -34,6 +34,9 @@ PASSWORD_MASK = "X" * 10
NO_TIME_RANGE = "No filter"
QUERY_CANCEL_KEY = "cancel_query"
QUERY_EARLY_CANCEL_KEY = "early_cancel_query"
class RouteMethod: # pylint: disable=too-few-public-methods
"""

View File

@ -15,6 +15,9 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=too-many-lines
from __future__ import annotations
import json
import logging
import re
@ -478,7 +481,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@classmethod
def get_engine(
cls,
database: "Database",
database: Database,
schema: Optional[str] = None,
source: Optional[utils.QuerySource] = None,
) -> ContextManager[Engine]:
@ -733,7 +736,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@classmethod
def extra_table_metadata( # pylint: disable=unused-argument
cls,
database: "Database",
database: Database,
table_name: str,
schema_name: Optional[str],
) -> Dict[str, Any]:
@ -750,7 +753,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@classmethod
def apply_limit_to_sql(
cls, sql: str, limit: int, database: "Database", force: bool = False
cls, sql: str, limit: int, database: Database, force: bool = False
) -> str:
"""
Alters the SQL statement to apply a LIMIT clause
@ -892,7 +895,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@classmethod
def df_to_sql(
cls,
database: "Database",
database: Database,
table: Table,
df: pd.DataFrame,
to_sql_kwargs: Dict[str, Any],
@ -939,7 +942,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return None
@classmethod
def handle_cursor(cls, cursor: Any, query: "Query", session: Session) -> None:
def handle_cursor(cls, cursor: Any, query: Query, session: Session) -> None:
"""Handle a live cursor between the execute and fetchall calls
The flow works without this method doing anything, but it allows
@ -1031,7 +1034,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@classmethod
def get_table_names( # pylint: disable=unused-argument
cls,
database: "Database",
database: Database,
inspector: Inspector,
schema: Optional[str],
) -> Set[str]:
@ -1059,7 +1062,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@classmethod
def get_view_names( # pylint: disable=unused-argument
cls,
database: "Database",
database: Database,
inspector: Inspector,
schema: Optional[str],
) -> Set[str]:
@ -1125,7 +1128,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@classmethod
def get_metrics( # pylint: disable=unused-argument
cls,
database: "Database",
database: Database,
inspector: Inspector,
table_name: str,
schema: Optional[str],
@ -1147,7 +1150,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
cls,
table_name: str,
schema: Optional[str],
database: "Database",
database: Database,
query: Select,
columns: Optional[List[Dict[str, str]]] = None,
) -> Optional[Select]:
@ -1172,7 +1175,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@classmethod
def select_star( # pylint: disable=too-many-arguments,too-many-locals
cls,
database: "Database",
database: Database,
table_name: str,
engine: Engine,
schema: Optional[str] = None,
@ -1251,7 +1254,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
raise Exception("Database does not support cost estimation")
@classmethod
def process_statement(cls, statement: str, database: "Database") -> str:
def process_statement(cls, statement: str, database: Database) -> str:
"""
Process a SQL statement by stripping and mutating it.
@ -1275,7 +1278,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@classmethod
def estimate_query_cost(
cls,
database: "Database",
database: Database,
schema: str,
sql: str,
source: Optional[utils.QuerySource] = None,
@ -1471,7 +1474,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@classmethod
def get_function_names( # pylint: disable=unused-argument
cls,
database: "Database",
database: Database,
) -> List[str]:
"""
Get a list of function names that are able to be called on the database.
@ -1496,7 +1499,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@staticmethod
def mutate_db_for_connection_test( # pylint: disable=unused-argument
database: "Database",
database: Database,
) -> None:
"""
Some databases require passing additional parameters for validating database
@ -1508,7 +1511,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return None
@staticmethod
def get_extra_params(database: "Database") -> Dict[str, Any]:
def get_extra_params(database: Database) -> Dict[str, Any]:
"""
Some databases require adding elements to connection parameters,
like passing certificates to `extra`. This can be done here.
@ -1527,7 +1530,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@staticmethod
def update_params_from_encrypted_extra( # pylint: disable=invalid-name
database: "Database", params: Dict[str, Any]
database: Database, params: Dict[str, Any]
) -> None:
"""
Some databases require some sensitive information which do not conform to
@ -1589,11 +1592,22 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
)
return None
# pylint: disable=unused-argument
@classmethod
def prepare_cancel_query(cls, query: Query, session: Session) -> None:
"""
Some databases may acquire the query cancelation id after the query
cancelation request has been received. For those cases, the db engine spec
can record the cancelation intent so that the query can either be stopped
prior to execution, or canceled once the query id is acquired.
"""
return None
@classmethod
def has_implicit_cancel(cls) -> bool:
"""
Return True if the live cursor handles the implicit cancelation of the query,
False otherise.
False otherwise.
:return: Whether the live cursor implicitly cancels the query
:see: handle_cursor
@ -1605,7 +1619,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
def get_cancel_query_id( # pylint: disable=unused-argument
cls,
cursor: Any,
query: "Query",
query: Query,
) -> Optional[str]:
"""
Select identifiers from the database engine that uniquely identifies the
@ -1623,7 +1637,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
def cancel_query( # pylint: disable=unused-argument
cls,
cursor: Any,
query: "Query",
query: Query,
cancel_query_id: str,
) -> bool:
"""

View File

@ -559,7 +559,7 @@ class HiveEngineSpec(PrestoEngineSpec):
def has_implicit_cancel(cls) -> bool:
"""
Return True if the live cursor handles the implicit cancelation of the query,
False otherise.
False otherwise.
:return: Whether the live cursor implicitly cancels the query
:see: handle_cursor

View File

@ -1307,7 +1307,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
def has_implicit_cancel(cls) -> bool:
"""
Return True if the live cursor handles the implicit cancelation of the query,
False otherise.
False otherwise.
:return: Whether the live cursor implicitly cancels the query
:see: handle_cursor

View File

@ -26,7 +26,7 @@ from flask import current_app
from sqlalchemy.engine.url import URL
from sqlalchemy.orm import Session
from superset.constants import USER_AGENT
from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY, USER_AGENT
from superset.databases.utils import make_url_safe
from superset.db_engine_specs.base import BaseEngineSpec
from superset.db_engine_specs.exceptions import SupersetDBAPIConnectionError
@ -181,11 +181,30 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
query.tracking_url = tracking_url
# Adds the executed query id to the extra payload so the query can be cancelled
query.set_extra_json_key("cancel_query", cursor.stats["queryId"])
query.set_extra_json_key(
key=QUERY_CANCEL_KEY,
value=(cancel_query_id := cursor.stats["queryId"]),
)
session.commit()
# if query cancelation was requested prior to the handle_cursor call, but
# the query was still executed, trigger the actual query cancelation now
if query.extra.get(QUERY_EARLY_CANCEL_KEY):
cls.cancel_query(
cursor=cursor,
query=query,
cancel_query_id=cancel_query_id,
)
super().handle_cursor(cursor=cursor, query=query, session=session)
@classmethod
def prepare_cancel_query(cls, query: Query, session: Session) -> None:
if QUERY_CANCEL_KEY not in query.extra:
query.set_extra_json_key(QUERY_EARLY_CANCEL_KEY, True)
session.commit()
@classmethod
def cancel_query(cls, cursor: Any, query: Query, cancel_query_id: str) -> bool:
"""

View File

@ -33,12 +33,14 @@ from sqlalchemy.orm import Session
from superset import (
app,
db,
is_feature_enabled,
results_backend,
results_backend_use_msgpack,
security_manager,
)
from superset.common.db_query_status import QueryStatus
from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY
from superset.dataframe import df_to_records
from superset.db_engine_specs import BaseEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
@ -69,7 +71,6 @@ SQLLAB_CTAS_NO_LIMIT = config["SQLLAB_CTAS_NO_LIMIT"]
SQL_QUERY_MUTATOR = config["SQL_QUERY_MUTATOR"]
log_query = config["QUERY_LOGGER"]
logger = logging.getLogger(__name__)
cancel_query_key = "cancel_query"
class SqlLabException(Exception):
@ -473,7 +474,7 @@ def execute_sql_statements( # pylint: disable=too-many-arguments, too-many-loca
cursor = conn.cursor()
cancel_query_id = db_engine_spec.get_cancel_query_id(cursor, query)
if cancel_query_id is not None:
query.set_extra_json_key(cancel_query_key, cancel_query_id)
query.set_extra_json_key(QUERY_CANCEL_KEY, cancel_query_id)
session.commit()
statement_count = len(statements)
for i, statement in enumerate(statements):
@ -613,7 +614,7 @@ def cancel_query(query: Query) -> bool:
"""
Cancel a running query.
Note some engines implicitly handle the cancelation of a query and thus no expliicit
Note some engines implicitly handle the cancelation of a query and thus no explicit
action is required.
:param query: Query to cancel
@ -623,7 +624,16 @@ def cancel_query(query: Query) -> bool:
if query.database.db_engine_spec.has_implicit_cancel():
return True
cancel_query_id = query.extra.get(cancel_query_key)
# Some databases may need to make preparations for query cancellation
query.database.db_engine_spec.prepare_cancel_query(query, db.session)
if query.extra.get(QUERY_EARLY_CANCEL_KEY):
# Query has been cancelled prior to being able to set the cancel key.
# This can happen if the query cancellation key can only be acquired after the
# query has been executed
return True
cancel_query_id = query.extra.get(QUERY_CANCEL_KEY)
if cancel_query_id is None:
return False

View File

@ -15,8 +15,15 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-argument, import-outside-toplevel, protected-access
import json
from typing import Any, Dict
from unittest import mock
import pytest
from pytest_mock import MockerFixture
from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY
@mock.patch("sqlalchemy.engine.Engine.connect")
def test_cancel_query_success(engine_mock: mock.Mock) -> None:
@ -36,3 +43,55 @@ def test_cancel_query_failed(engine_mock: mock.Mock) -> None:
query = Query()
cursor_mock = engine_mock.raiseError.side_effect = Exception()
assert TrinoEngineSpec.cancel_query(cursor_mock, query, "123") is False
@pytest.mark.parametrize(
"initial_extra,final_extra",
[
({}, {QUERY_EARLY_CANCEL_KEY: True}),
({QUERY_CANCEL_KEY: "my_key"}, {QUERY_CANCEL_KEY: "my_key"}),
],
)
def test_prepare_cancel_query(
initial_extra: Dict[str, Any],
final_extra: Dict[str, Any],
mocker: MockerFixture,
) -> None:
from superset.db_engine_specs.trino import TrinoEngineSpec
from superset.models.sql_lab import Query
session_mock = mocker.MagicMock()
query = Query(extra_json=json.dumps(initial_extra))
TrinoEngineSpec.prepare_cancel_query(query=query, session=session_mock)
assert query.extra == final_extra
@pytest.mark.parametrize("cancel_early", [True, False])
@mock.patch("superset.db_engine_specs.trino.TrinoEngineSpec.cancel_query")
@mock.patch("sqlalchemy.engine.Engine.connect")
def test_handle_cursor_early_cancel(
engine_mock: mock.Mock,
cancel_query_mock: mock.Mock,
cancel_early: bool,
mocker: MockerFixture,
) -> None:
from superset.db_engine_specs.trino import TrinoEngineSpec
from superset.models.sql_lab import Query
query_id = "myQueryId"
cursor_mock = engine_mock.return_value.__enter__.return_value
cursor_mock.stats = {"queryId": query_id}
session_mock = mocker.MagicMock()
query = Query()
if cancel_early:
TrinoEngineSpec.prepare_cancel_query(query=query, session=session_mock)
TrinoEngineSpec.handle_cursor(cursor=cursor_mock, query=query, session=session_mock)
if cancel_early:
assert cancel_query_mock.call_args[1]["cancel_query_id"] == query_id
else:
assert cancel_query_mock.call_args is None