diff --git a/superset/views/base.py b/superset/views/base.py index 0243fb1e2..fce500e79 100644 --- a/superset/views/base.py +++ b/superset/views/base.py @@ -170,64 +170,66 @@ class BaseSupersetView(BaseView): mimetype="application/json", ) - def menu_data(self): - menu = appbuilder.menu.get_data() - root_path = "#" - logo_target_path = "" - if not g.user.is_anonymous: - try: - logo_target_path = ( - appbuilder.app.config.get("LOGO_TARGET_PATH") - or f"/profile/{g.user.username}/" - ) - # when user object has no username - except NameError as e: - logging.exception(e) - if logo_target_path.startswith("/"): - root_path = f"/superset{logo_target_path}" - else: - root_path = logo_target_path +def menu_data(): + menu = appbuilder.menu.get_data() + root_path = "#" + logo_target_path = "" + if not g.user.is_anonymous: + try: + logo_target_path = ( + appbuilder.app.config.get("LOGO_TARGET_PATH") + or f"/profile/{g.user.username}/" + ) + # when user object has no username + except NameError as e: + logging.exception(e) - languages = {} - for lang in appbuilder.languages: - languages[lang] = { - **appbuilder.languages[lang], - "url": appbuilder.get_url_for_locale(lang), - } - return { - "menu": menu, - "brand": { - "path": root_path, - "icon": appbuilder.app_icon, - "alt": appbuilder.app_name, - }, - "navbar_right": { - "bug_report_url": appbuilder.app.config.get("BUG_REPORT_URL"), - "documentation_url": appbuilder.app.config.get("DOCUMENTATION_URL"), - "languages": languages, - "show_language_picker": len(languages.keys()) > 1, - "user_is_anonymous": g.user.is_anonymous, - "user_info_url": appbuilder.get_url_for_userinfo, - "user_logout_url": appbuilder.get_url_for_logout, - "user_login_url": appbuilder.get_url_for_login, - "locale": session.get("locale", "en"), - }, + if logo_target_path.startswith("/"): + root_path = f"/superset{logo_target_path}" + else: + root_path = logo_target_path + + languages = {} + for lang in appbuilder.languages: + languages[lang] = { + **appbuilder.languages[lang], + "url": appbuilder.get_url_for_locale(lang), } + return { + "menu": menu, + "brand": { + "path": root_path, + "icon": appbuilder.app_icon, + "alt": appbuilder.app_name, + }, + "navbar_right": { + "bug_report_url": appbuilder.app.config.get("BUG_REPORT_URL"), + "documentation_url": appbuilder.app.config.get("DOCUMENTATION_URL"), + "languages": languages, + "show_language_picker": len(languages.keys()) > 1, + "user_is_anonymous": g.user.is_anonymous, + "user_info_url": appbuilder.get_url_for_userinfo, + "user_logout_url": appbuilder.get_url_for_logout, + "user_login_url": appbuilder.get_url_for_login, + "locale": session.get("locale", "en"), + }, + } - def common_bootstrap_payload(self): - """Common data always sent to the client""" - messages = get_flashed_messages(with_categories=True) - locale = str(get_locale()) - return { - "flash_messages": messages, - "conf": {k: conf.get(k) for k in FRONTEND_CONF_KEYS}, - "locale": locale, - "language_pack": get_language_pack(locale), - "feature_flags": get_feature_flags(), - "menu_data": self.menu_data(), - } +def common_bootstrap_payload(): + """Common data always sent to the client""" + messages = get_flashed_messages(with_categories=True) + locale = str(get_locale()) + + return { + "flash_messages": messages, + "conf": {k: conf.get(k) for k in FRONTEND_CONF_KEYS}, + "locale": locale, + "language_pack": get_language_pack(locale), + "feature_flags": get_feature_flags(), + "menu_data": menu_data(), + } class SupersetListWidget(ListWidget): diff --git a/superset/views/core.py b/superset/views/core.py index 958f6421d..dc9bec6aa 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -20,7 +20,7 @@ import re from contextlib import closing from datetime import datetime, timedelta from enum import Enum -from typing import cast, List, Optional, Union +from typing import Any, cast, Dict, List, Optional, Union from urllib import parse import backoff @@ -45,7 +45,7 @@ from flask_appbuilder.models.sqla.interface import SQLAInterface from flask_appbuilder.security.decorators import has_access, has_access_api from flask_appbuilder.security.sqla import models as ab_models from flask_babel import gettext as __, lazy_gettext as _ -from sqlalchemy import and_, or_, select +from sqlalchemy import and_, Integer, or_, select from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm.session import Session from werkzeug.routing import BaseConverter @@ -91,6 +91,7 @@ from .base import ( BaseFilter, BaseSupersetView, check_ownership, + common_bootstrap_payload, CsvResponse, data_payload_response, DeleteMixin, @@ -1292,7 +1293,7 @@ class Superset(BaseSupersetView): "standalone": standalone, "user_id": user_id, "forced_height": request.args.get("height"), - "common": self.common_bootstrap_payload(), + "common": common_bootstrap_payload(), } table_name = ( datasource.table_name @@ -2229,7 +2230,7 @@ class Superset(BaseSupersetView): "user_id": g.user.get_id(), "dashboard_data": dashboard_data, "datasources": {ds.uid: ds.data for ds in datasources}, - "common": self.common_bootstrap_payload(), + "common": common_bootstrap_payload(), "editMode": edit_mode, "urlParams": url_params, } @@ -3032,7 +3033,7 @@ class Superset(BaseSupersetView): payload = { "user": bootstrap_user_data(g.user), - "common": self.common_bootstrap_payload(), + "common": common_bootstrap_payload(), } return self.render_template( @@ -3058,7 +3059,7 @@ class Superset(BaseSupersetView): payload = { "user": bootstrap_user_data(user, include_perms=True), - "common": self.common_bootstrap_payload(), + "common": common_bootstrap_payload(), } return self.render_template( @@ -3070,27 +3071,25 @@ class Superset(BaseSupersetView): ), ) - @has_access - @expose("/sqllab") - def sqllab(self): - """SQL Editor""" - + @staticmethod + def _get_sqllab_payload(user_id: int) -> Dict[str, Any]: # send list of tab state ids - tab_state_ids = ( + tabs_state = ( db.session.query(TabState.id, TabState.label) - .filter_by(user_id=g.user.get_id()) + .filter_by(user_id=user_id) .all() ) + tab_state_ids = [tab_state[0] for tab_state in tabs_state] # return first active tab, or fallback to another one if no tab is active active_tab = ( db.session.query(TabState) - .filter_by(user_id=g.user.get_id()) + .filter_by(user_id=user_id) .order_by(TabState.active.desc()) .first() ) - databases = {} - queries = {} + databases: Dict[int, Any] = {} + queries: Dict[str, Any] = {} # These are unnecessary if sqllab backend persistence is disabled if is_feature_enabled("SQLLAB_BACKEND_PERSISTENCE"): @@ -3100,26 +3099,38 @@ class Superset(BaseSupersetView): } for database in db.session.query(models.Database).all() } + # return all user queries associated with existing SQL editors user_queries = ( - db.session.query(Query).filter_by(user_id=g.user.get_id()).all() + db.session.query(Query) + .filter_by(user_id=user_id) + .filter(Query.sql_editor_id.cast(Integer).in_(tab_state_ids)) + .all() ) queries = { query.client_id: {k: v for k, v in query.to_dict().items()} for query in user_queries } - d = { + return { "defaultDbId": config["SQLLAB_DEFAULT_DBID"], - "common": self.common_bootstrap_payload(), - "tab_state_ids": tab_state_ids, + "common": common_bootstrap_payload(), + "tab_state_ids": tabs_state, "active_tab": active_tab.to_dict() if active_tab else None, "databases": databases, "queries": queries, } + + @has_access + @expose("/sqllab") + def sqllab(self): + """SQL Editor""" + payload = self._get_sqllab_payload(g.user.get_id()) + bootstrap_data = json.dumps( + payload, default=utils.pessimistic_json_iso_dttm_ser + ) + return self.render_template( - "superset/basic.html", - entry="sqllab", - bootstrap_data=json.dumps(d, default=utils.pessimistic_json_iso_dttm_ser), + "superset/basic.html", entry="sqllab", bootstrap_data=bootstrap_data ) @api diff --git a/tests/base_tests.py b/tests/base_tests.py index 7399774cc..0f32872ad 100644 --- a/tests/base_tests.py +++ b/tests/base_tests.py @@ -193,6 +193,7 @@ class SupersetTestCase(TestCase): raise_on_error=False, query_limit=None, database_name="examples", + sql_editor_id=None, ): if user_name: self.logout() @@ -207,6 +208,7 @@ class SupersetTestCase(TestCase): select_as_create_as=False, client_id=client_id, queryLimit=query_limit, + sql_editor_id=sql_editor_id, ), ) if raise_on_error and "error" in resp: diff --git a/tests/core_tests.py b/tests/core_tests.py index b227c5853..ae7cf633e 100644 --- a/tests/core_tests.py +++ b/tests/core_tests.py @@ -983,6 +983,53 @@ class CoreTests(SupersetTestCase): data = self.get_resp(url) self.assertTrue(html in data) + @mock.patch.dict( + "superset.extensions.feature_flag_manager._feature_flags", + {"SQLLAB_BACKEND_PERSISTENCE": True}, + clear=True, + ) + def test_sqllab_backend_persistence_payload(self): + username = "admin" + self.login(username) + user_id = security_manager.find_user(username).id + + # create a tab + data = { + "queryEditor": json.dumps( + { + "title": "Untitled Query 1", + "dbId": 1, + "schema": None, + "autorun": False, + "sql": "SELECT ...", + "queryLimit": 1000, + } + ) + } + resp = self.get_json_resp("/tabstateview/", data=data) + tab_state_id = resp["id"] + + # run a query in the created tab + self.run_sql( + "SELECT name FROM birth_names", + "client_id_1", + user_name=username, + raise_on_error=True, + sql_editor_id=tab_state_id, + ) + # run an orphan query (no tab) + self.run_sql( + "SELECT name FROM birth_names", + "client_id_2", + user_name=username, + raise_on_error=True, + ) + + # we should have only 1 query returned, since the second one is not + # associated with any tabs + payload = views.Superset._get_sqllab_payload(user_id=user_id) + self.assertEqual(len(payload["queries"]), 1) + if __name__ == "__main__": unittest.main()