feat(sql lab): display presto and trino tracking url (#20799)

This commit is contained in:
Jesse Yang 2022-07-26 20:20:08 -07:00 committed by GitHub
parent 35184b2994
commit 77db0651d8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 340 additions and 84 deletions

View File

@ -25,6 +25,7 @@ assists people when migrating to a new version.
## Next ## Next
- [20606](https://github.com/apache/superset/pull/20606): When user clicks on chart title or "Edit chart" button in Dashboard page, Explore opens in the same tab. Clicking while holding cmd/ctrl opens Explore in a new tab. To bring back the old behaviour (always opening Explore in a new tab), flip feature flag `DASHBOARD_EDIT_CHART_IN_NEW_TAB` to `True`. - [20606](https://github.com/apache/superset/pull/20606): When user clicks on chart title or "Edit chart" button in Dashboard page, Explore opens in the same tab. Clicking while holding cmd/ctrl opens Explore in a new tab. To bring back the old behaviour (always opening Explore in a new tab), flip feature flag `DASHBOARD_EDIT_CHART_IN_NEW_TAB` to `True`.
- [20799](https://github.com/apache/superset/pull/20799): Presto and Trino engine will now display tracking URL for running queries in SQL Lab. If for some reason you don't want to show the tracking URL (for example, when your data warehouse hasn't enabled access to the Presto or Trino UI), update `TRACKING_URL_TRANSFORMER` in `config.py` to return `None`.
### Breaking Changes ### Breaking Changes

View File

@ -54,6 +54,20 @@ You can run unit tests found in './tests/unit_tests' for example with pytest. It
pytest ./link_to_test.py pytest ./link_to_test.py
``` ```
#### Testing with local Presto connections
If you happen to change db engine spec for Presto/Trino, you can run a local Presto cluster with Docker:
```bash
docker run -p 15433:15433 starburstdata/presto:350-e.6
```
Then update `SUPERSET__SQLALCHEMY_EXAMPLES_URI` to point to local Presto cluster:
```bash
export SUPERSET__SQLALCHEMY_EXAMPLES_URI=presto://localhost:15433/memory/default
```
### Frontend Testing ### Frontend Testing
We use [Jest](https://jestjs.io/) and [Enzyme](https://airbnb.io/enzyme/) to test TypeScript/JavaScript. Tests can be run with: We use [Jest](https://jestjs.io/) and [Enzyme](https://airbnb.io/enzyme/) to test TypeScript/JavaScript. Tests can be run with:

View File

@ -109,6 +109,9 @@ const ResultSetButtons = styled.div`
const ResultSetErrorMessage = styled.div` const ResultSetErrorMessage = styled.div`
padding-top: ${({ theme }) => 4 * theme.gridUnit}px; padding-top: ${({ theme }) => 4 * theme.gridUnit}px;
.sql-result-track-job {
margin-top: ${({ theme }) => 2 * theme.gridUnit}px;
}
`; `;
export default class ResultSet extends React.PureComponent< export default class ResultSet extends React.PureComponent<
@ -417,6 +420,19 @@ export default class ResultSet extends React.PureComponent<
if (this.props.database && this.props.database.explore_database_id) { if (this.props.database && this.props.database.explore_database_id) {
exploreDBId = this.props.database.explore_database_id; exploreDBId = this.props.database.explore_database_id;
} }
let trackingUrl;
if (query.trackingUrl) {
trackingUrl = (
<Button
className="sql-result-track-job"
buttonSize="small"
href={query.trackingUrl}
target="_blank"
>
{query.state === 'running' ? t('Track job') : t('See query details')}
</Button>
);
}
if (this.props.showSql) sql = <HighlightedSql sql={query.sql} />; if (this.props.showSql) sql = <HighlightedSql sql={query.sql} />;
@ -434,6 +450,7 @@ export default class ResultSet extends React.PureComponent<
link={query.link} link={query.link}
source="sqllab" source="sqllab"
/> />
{trackingUrl}
</ResultSetErrorMessage> </ResultSetErrorMessage>
); );
} }
@ -550,7 +567,6 @@ export default class ResultSet extends React.PureComponent<
); );
} }
} }
let trackingUrl;
let progressBar; let progressBar;
if (query.progress > 0) { if (query.progress > 0) {
progressBar = ( progressBar = (
@ -560,16 +576,6 @@ export default class ResultSet extends React.PureComponent<
/> />
); );
} }
if (query.trackingUrl) {
trackingUrl = (
<Button
buttonSize="small"
onClick={() => query.trackingUrl && window.open(query.trackingUrl)}
>
{t('Track job')}
</Button>
);
}
const progressMsg = const progressMsg =
query && query.extra && query.extra.progress query && query.extra && query.extra.progress
? query.extra.progress ? query.extra.progress

View File

@ -16,13 +16,15 @@
* specific language governing permissions and limitations * specific language governing permissions and limitations
* under the License. * under the License.
*/ */
import React, { CSSProperties, Children, ReactElement } from 'react'; import React, { Children, ReactElement } from 'react';
import { kebabCase } from 'lodash'; import { kebabCase } from 'lodash';
import { mix } from 'polished'; import { mix } from 'polished';
import cx from 'classnames'; import cx from 'classnames';
import { AntdButton } from 'src/components'; import { AntdButton } from 'src/components';
import { useTheme } from '@superset-ui/core'; import { useTheme } from '@superset-ui/core';
import { Tooltip } from 'src/components/Tooltip'; import { Tooltip } from 'src/components/Tooltip';
import { ButtonProps as AntdButtonProps } from 'antd/lib/button';
import { TooltipProps } from 'antd/lib/tooltip';
export type OnClickHandler = React.MouseEventHandler<HTMLElement>; export type OnClickHandler = React.MouseEventHandler<HTMLElement>;
@ -37,45 +39,15 @@ export type ButtonStyle =
| 'link' | 'link'
| 'dashed'; | 'dashed';
export interface ButtonProps { export type ButtonProps = Omit<AntdButtonProps, 'css'> &
id?: string; Pick<TooltipProps, 'placement'> & {
className?: string; tooltip?: string;
tooltip?: string; className?: string;
ghost?: boolean; buttonSize?: 'default' | 'small' | 'xsmall';
placement?: buttonStyle?: ButtonStyle;
| 'bottom' cta?: boolean;
| 'left' showMarginRight?: boolean;
| 'right' };
| 'top'
| 'topLeft'
| 'topRight'
| 'bottomLeft'
| 'bottomRight'
| 'leftTop'
| 'leftBottom'
| 'rightTop'
| 'rightBottom';
onClick?: OnClickHandler;
onMouseDown?: OnClickHandler;
disabled?: boolean;
buttonStyle?: ButtonStyle;
buttonSize?: 'default' | 'small' | 'xsmall';
style?: CSSProperties;
children?: React.ReactNode;
href?: string;
htmlType?: 'button' | 'submit' | 'reset';
cta?: boolean;
loading?: boolean | { delay?: number | undefined } | undefined;
showMarginRight?: boolean;
type?:
| 'default'
| 'text'
| 'link'
| 'primary'
| 'dashed'
| 'ghost'
| undefined;
}
export default function Button(props: ButtonProps) { export default function Button(props: ButtonProps) {
const { const {

View File

@ -995,7 +995,13 @@ BLUEPRINTS: List[Blueprint] = []
# into a proxied one # into a proxied one
TRACKING_URL_TRANSFORMER = lambda x: x # Transform SQL query tracking url for Hive and Presto engines. You may also
# access information about the query itself by adding a second parameter
# to your transformer function, e.g.:
# TRACKING_URL_TRANSFORMER = (
# lambda url, query: url if is_fresh(query) else None
# )
TRACKING_URL_TRANSFORMER = lambda url: url
# Interval between consecutive polls when using Hive Engine # Interval between consecutive polls when using Hive Engine

View File

@ -315,7 +315,7 @@ class HiveEngineSpec(PrestoEngineSpec):
return int(progress) return int(progress)
@classmethod @classmethod
def get_tracking_url(cls, log_lines: List[str]) -> Optional[str]: def get_tracking_url_from_logs(cls, log_lines: List[str]) -> Optional[str]:
lkp = "Tracking URL = " lkp = "Tracking URL = "
for line in log_lines: for line in log_lines:
if lkp in line: if lkp in line:
@ -366,7 +366,7 @@ class HiveEngineSpec(PrestoEngineSpec):
query.progress = progress query.progress = progress
needs_commit = True needs_commit = True
if not tracking_url: if not tracking_url:
tracking_url = cls.get_tracking_url(log_lines) tracking_url = cls.get_tracking_url_from_logs(log_lines)
if tracking_url: if tracking_url:
job_id = tracking_url.split("/")[-2] job_id = tracking_url.split("/")[-2]
logger.info( logger.info(
@ -374,13 +374,6 @@ class HiveEngineSpec(PrestoEngineSpec):
str(query_id), str(query_id),
tracking_url, tracking_url,
) )
transformer = current_app.config["TRACKING_URL_TRANSFORMER"]
tracking_url = transformer(tracking_url)
logger.info(
"Query %s: Transformation applied: %s",
str(query_id),
tracking_url,
)
query.tracking_url = tracking_url query.tracking_url = tracking_url
logger.info("Query %s: Job id: %s", str(query_id), str(job_id)) logger.info("Query %s: Job id: %s", str(query_id), str(job_id))
needs_commit = True needs_commit = True

View File

@ -64,6 +64,12 @@ if TYPE_CHECKING:
# prevent circular imports # prevent circular imports
from superset.models.core import Database from superset.models.core import Database
# need try/catch because pyhive may not be installed
try:
from pyhive.presto import Cursor # pylint: disable=unused-import
except ImportError:
pass
COLUMN_DOES_NOT_EXIST_REGEX = re.compile( COLUMN_DOES_NOT_EXIST_REGEX = re.compile(
"line (?P<location>.+?): .*Column '(?P<column_name>.+?)' cannot be resolved" "line (?P<location>.+?): .*Column '(?P<column_name>.+?)' cannot be resolved"
) )
@ -957,8 +963,23 @@ class PrestoEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-metho
return rows[0][0] return rows[0][0]
@classmethod @classmethod
def handle_cursor(cls, cursor: Any, query: Query, session: Session) -> None: def get_tracking_url(cls, cursor: "Cursor") -> Optional[str]:
try:
if cursor.last_query_id:
# pylint: disable=protected-access, line-too-long
return f"{cursor._protocol}://{cursor._host}:{cursor._port}/ui/query.html?{cursor.last_query_id}"
except AttributeError:
pass
return None
@classmethod
def handle_cursor(cls, cursor: "Cursor", query: Query, session: Session) -> None:
"""Updates progress information""" """Updates progress information"""
tracking_url = cls.get_tracking_url(cursor)
if tracking_url:
query.tracking_url = tracking_url
session.commit()
query_id = query.id query_id = query.id
poll_interval = query.database.connect_args.get( poll_interval = query.database.connect_args.get(
"poll_interval", current_app.config["PRESTO_POLL_INTERVAL"] "poll_interval", current_app.config["PRESTO_POLL_INTERVAL"]

View File

@ -32,6 +32,11 @@ from superset.utils import core as utils
if TYPE_CHECKING: if TYPE_CHECKING:
from superset.models.core import Database from superset.models.core import Database
try:
from trino.dbapi import Cursor # pylint: disable=unused-import
except ImportError:
pass
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -109,8 +114,25 @@ class TrinoEngineSpec(PrestoEngineSpec):
) )
@classmethod @classmethod
def handle_cursor(cls, cursor: Any, query: Query, session: Session) -> None: def get_tracking_url(cls, cursor: "Cursor") -> Optional[str]:
try:
return cursor.info_uri
except AttributeError:
try:
conn = cursor.connection
# pylint: disable=protected-access, line-too-long
return f"{conn.http_scheme}://{conn.host}:{conn.port}/ui/query.html?{cursor._query.query_id}"
except AttributeError:
pass
return None
@classmethod
def handle_cursor(cls, cursor: "Cursor", query: Query, session: Session) -> None:
"""Updates progress information""" """Updates progress information"""
tracking_url = cls.get_tracking_url(cursor)
if tracking_url:
query.tracking_url = tracking_url
session.commit()
BaseEngineSpec.handle_cursor(cursor=cursor, query=query, session=session) BaseEngineSpec.handle_cursor(cursor=cursor, query=query, session=session)
@staticmethod @staticmethod

View File

@ -15,13 +15,15 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
"""A collection of ORM sqlalchemy models for SQL Lab""" """A collection of ORM sqlalchemy models for SQL Lab"""
import inspect
import logging
import re import re
from datetime import datetime from datetime import datetime
from typing import Any, Dict, List, Optional, Type, TYPE_CHECKING from typing import Any, Dict, List, Optional, Type, TYPE_CHECKING
import simplejson as json import simplejson as json
import sqlalchemy as sqla import sqlalchemy as sqla
from flask import Markup from flask import current_app, Markup
from flask_appbuilder import Model from flask_appbuilder import Model
from flask_appbuilder.models.decorators import renders from flask_appbuilder.models.decorators import renders
from humanize import naturaltime from humanize import naturaltime
@ -56,6 +58,9 @@ if TYPE_CHECKING:
from superset.db_engine_specs import BaseEngineSpec from superset.db_engine_specs import BaseEngineSpec
logger = logging.getLogger(__name__)
class Query(Model, ExtraJSONMixin, ExploreMixin): # pylint: disable=abstract-method class Query(Model, ExtraJSONMixin, ExploreMixin): # pylint: disable=abstract-method
"""ORM model for SQL query """ORM model for SQL query
@ -104,7 +109,7 @@ class Query(Model, ExtraJSONMixin, ExploreMixin): # pylint: disable=abstract-me
start_running_time = Column(Numeric(precision=20, scale=6)) start_running_time = Column(Numeric(precision=20, scale=6))
end_time = Column(Numeric(precision=20, scale=6)) end_time = Column(Numeric(precision=20, scale=6))
end_result_backend_time = Column(Numeric(precision=20, scale=6)) end_result_backend_time = Column(Numeric(precision=20, scale=6))
tracking_url = Column(Text) tracking_url_raw = Column(Text, name="tracking_url")
changed_on = Column( changed_on = Column(
DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=True DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=True
@ -283,6 +288,27 @@ class Query(Model, ExtraJSONMixin, ExploreMixin): # pylint: disable=abstract-me
def get_extra_cache_keys(query_obj: Dict[str, Any]) -> List[str]: def get_extra_cache_keys(query_obj: Dict[str, Any]) -> List[str]:
return [] return []
@property
def tracking_url(self) -> Optional[str]:
"""
        Transform tracking url at run time because the exact URL may depend
on query properties such as execution and finish time.
"""
transform = current_app.config.get("TRACKING_URL_TRANSFORMER")
url = self.tracking_url_raw
if url and transform:
sig = inspect.signature(transform)
# for backward compatibility, users may define a transformer function
# with only one parameter (`url`).
args = [url, self][: len(sig.parameters)]
url = transform(*args)
logger.debug("Transformed tracking url: %s", url)
return url
@tracking_url.setter
def tracking_url(self, value: str) -> None:
self.tracking_url_raw = value
class SavedQuery(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin): class SavedQuery(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin):
"""ORM model for SQL query""" """ORM model for SQL query"""

View File

@ -96,8 +96,13 @@ def handle_query_error(
msg = f"{prefix_message} {str(ex)}".strip() msg = f"{prefix_message} {str(ex)}".strip()
troubleshooting_link = config["TROUBLESHOOTING_LINK"] troubleshooting_link = config["TROUBLESHOOTING_LINK"]
query.error_message = msg query.error_message = msg
query.status = QueryStatus.FAILED
query.tmp_table_name = None query.tmp_table_name = None
query.status = QueryStatus.FAILED
# TODO: re-enable this after updating the frontend to properly display timeout status
# if query.status != QueryStatus.TIMED_OUT:
# query.status = QueryStatus.FAILED
if not query.end_time:
query.end_time = now_as_float()
# extract DB-specific errors (invalid column, eg) # extract DB-specific errors (invalid column, eg)
if isinstance(ex, SupersetErrorException): if isinstance(ex, SupersetErrorException):
@ -286,6 +291,8 @@ def execute_sql_statement( # pylint: disable=too-many-arguments,too-many-statem
# return 1 row less than increased_query # return 1 row less than increased_query
data = data[:-1] data = data[:-1]
except SoftTimeLimitExceeded as ex: except SoftTimeLimitExceeded as ex:
query.status = QueryStatus.TIMED_OUT
logger.warning("Query %d: Time limit exceeded", query.id) logger.warning("Query %d: Time limit exceeded", query.id)
logger.debug("Query %d: %s", query.id, ex) logger.debug("Query %d: %s", query.id, ex)
raise SupersetErrorException( raise SupersetErrorException(

View File

@ -25,7 +25,7 @@ from superset.exceptions import SupersetException
MSG_FORMAT = "Failed to execute {}" MSG_FORMAT = "Failed to execute {}"
if TYPE_CHECKING: if TYPE_CHECKING:
from superset.utils.sqllab_execution_context import SqlJsonExecutionContext from superset.sqllab.sqllab_execution_context import SqlJsonExecutionContext
class SqlLabException(SupersetException): class SqlLabException(SupersetException):

View File

@ -22,6 +22,7 @@ EPOCH = datetime(1970, 1, 1)
def datetime_to_epoch(dttm: datetime) -> float: def datetime_to_epoch(dttm: datetime) -> float:
    """Convert datetime to milliseconds since the epoch"""
if dttm.tzinfo: if dttm.tzinfo:
dttm = dttm.replace(tzinfo=pytz.utc) dttm = dttm.replace(tzinfo=pytz.utc)
epoch_with_tz = pytz.utc.localize(EPOCH) epoch_with_tz = pytz.utc.localize(EPOCH)

View File

@ -2322,6 +2322,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
raise SupersetCancelQueryException("Could not cancel query") raise SupersetCancelQueryException("Could not cancel query")
query.status = QueryStatus.STOPPED query.status = QueryStatus.STOPPED
query.end_time = now_as_float()
db.session.commit() db.session.commit()
return self.json_response("OK") return self.json_response("OK")

View File

@ -16,15 +16,18 @@
# under the License. # under the License.
from __future__ import annotations from __future__ import annotations
import contextlib
import functools import functools
from operator import ge
from typing import Any, Callable, Optional, TYPE_CHECKING from typing import Any, Callable, Optional, TYPE_CHECKING
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from flask.ctx import AppContext from flask.ctx import AppContext
from flask_appbuilder.security.sqla import models as ab_models
from sqlalchemy.engine import Engine from sqlalchemy.engine import Engine
from superset import db from superset import db, security_manager
from superset.extensions import feature_flag_manager from superset.extensions import feature_flag_manager
from superset.utils.core import json_dumps_w_dates from superset.utils.core import json_dumps_w_dates
from superset.utils.database import get_example_database, remove_database from superset.utils.database import get_example_database, remove_database
@ -68,6 +71,50 @@ def login_as_admin(login_as: Callable[..., None]):
yield login_as("admin") yield login_as("admin")
@pytest.fixture
def create_user(app_context: AppContext):
def _create_user(username: str, role: str = "Admin", password: str = "general"):
security_manager.add_user(
username,
"firstname",
"lastname",
"email@exaple.com",
security_manager.find_role(role),
password,
)
return security_manager.find_user(username)
return _create_user
@pytest.fixture
def get_user(app_context: AppContext):
def _get_user(username: str) -> ab_models.User:
return (
db.session.query(security_manager.user_model)
.filter_by(username=username)
.one_or_none()
)
return _get_user
@pytest.fixture
def get_or_create_user(get_user, create_user) -> ab_models.User:
@contextlib.contextmanager
def _get_user(username: str) -> ab_models.User:
user = get_user(username)
if not user:
# if user is created by test, remove it after done
user = create_user(username)
yield user
db.session.delete(user)
else:
yield user
return _get_user
@pytest.fixture(autouse=True, scope="session") @pytest.fixture(autouse=True, scope="session")
def setup_sample_data() -> Any: def setup_sample_data() -> Any:
# TODO(john-bodley): Determine a cleaner way of setting up the sample data without # TODO(john-bodley): Determine a cleaner way of setting up the sample data without

View File

@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@ -0,0 +1,71 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Callable, ContextManager
import pytest
from flask_appbuilder.security.sqla import models as ab_models
from superset import db
from superset.models.sql_lab import Query
from superset.utils.core import shortid
from superset.utils.database import get_example_database
def force_async_run(allow_run_async: bool):
example_db = get_example_database()
orig_allow_run_async = example_db.allow_run_async
example_db.allow_run_async = allow_run_async
db.session.commit()
yield example_db
example_db.allow_run_async = orig_allow_run_async
db.session.commit()
@pytest.fixture
def non_async_example_db(app_context):
gen = force_async_run(False)
yield next(gen)
try:
next(gen)
except StopIteration:
pass
@pytest.fixture
def async_example_db(app_context):
gen = force_async_run(True)
yield next(gen)
try:
next(gen)
except StopIteration:
pass
@pytest.fixture
def example_query(get_or_create_user: Callable[..., ContextManager[ab_models.User]]):
with get_or_create_user("sqllab-test-user") as user:
query = Query(
client_id=shortid()[:10], database=get_example_database(), user=user
)
db.session.add(query)
db.session.commit()
yield query
db.session.delete(query)
db.session.commit()

View File

@ -0,0 +1,56 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from superset import app, db
from superset.common.db_query_status import QueryStatus
from superset.models.core import Database
from superset.models.sql_lab import Query
from superset.sql_lab import execute_sql_statements
from superset.utils.dates import now_as_float
def test_non_async_execute(non_async_example_db: Database, example_query: Query):
"""Test query.tracking_url is attached for Presto and Hive queries"""
result = execute_sql_statements(
example_query.id,
"select 1 as foo;",
store_results=False,
return_results=True,
session=db.session,
start_time=now_as_float(),
expand_data=True,
log_params=dict(),
)
assert result
assert result["query_id"] == example_query.id
assert result["status"] == QueryStatus.SUCCESS
assert result["data"] == [{"foo": 1}]
    # should attach tracking URL for Presto & Hive
if non_async_example_db.db_engine_spec.engine == "presto":
assert example_query.tracking_url
assert "/ui/query.html?" in example_query.tracking_url
app.config["TRACKING_URL_TRANSFORMER"] = lambda url, query: url.replace(
"/ui/query.html?", f"/{query.client_id}/"
)
assert f"/{example_query.client_id}/" in example_query.tracking_url
app.config["TRACKING_URL_TRANSFORMER"] = lambda url: url + "&foo=bar"
assert example_query.tracking_url.endswith("&foo=bar")
if non_async_example_db.db_engine_spec.engine_name == "hive":
assert example_query.tracking_url_raw

View File

@ -18,6 +18,7 @@
"""Unit tests for Sql Lab""" """Unit tests for Sql Lab"""
import json import json
from datetime import datetime, timedelta from datetime import datetime, timedelta
from math import ceil, floor
import pytest import pytest
from celery.exceptions import SoftTimeLimitExceeded from celery.exceptions import SoftTimeLimitExceeded
@ -70,8 +71,8 @@ class TestSqlLab(SupersetTestCase):
db.session.query(Query).delete() db.session.query(Query).delete()
db.session.commit() db.session.commit()
self.run_sql(QUERY_1, client_id="client_id_1", username="admin") self.run_sql(QUERY_1, client_id="client_id_1", username="admin")
self.run_sql(QUERY_2, client_id="client_id_3", username="admin") self.run_sql(QUERY_2, client_id="client_id_2", username="admin")
self.run_sql(QUERY_3, client_id="client_id_2", username="gamma_sqllab") self.run_sql(QUERY_3, client_id="client_id_3", username="gamma_sqllab")
self.logout() self.logout()
def tearDown(self): def tearDown(self):
@ -406,22 +407,17 @@ class TestSqlLab(SupersetTestCase):
self.assertEqual(2, len(data)) self.assertEqual(2, len(data))
self.assertIn("birth", data[0]["sql"]) self.assertIn("birth", data[0]["sql"])
def test_search_query_on_time(self): def test_search_query_filter_by_time(self):
self.run_some_queries() self.run_some_queries()
self.login("admin") self.login("admin")
first_query_time = ( from_time = floor(
db.session.query(Query).filter_by(sql=QUERY_1).one() (db.session.query(Query).filter_by(sql=QUERY_1).one()).start_time
).start_time )
second_query_time = ( to_time = ceil(
db.session.query(Query).filter_by(sql=QUERY_3).one() (db.session.query(Query).filter_by(sql=QUERY_2).one()).start_time
).start_time )
# Test search queries on time filter url = f"/superset/search_queries?from={from_time}&to={to_time}"
from_time = "from={}".format(int(first_query_time)) assert len(self.client.get(url).json) == 2
to_time = "to={}".format(int(second_query_time))
params = [from_time, to_time]
resp = self.get_resp("/superset/search_queries?" + "&".join(params))
data = json.loads(resp)
self.assertEqual(2, len(data))
def test_search_query_only_owned(self) -> None: def test_search_query_only_owned(self) -> None:
""" """