diff --git a/superset/security/manager.py b/superset/security/manager.py index 47b9bc2d9..a7a84648f 100644 --- a/superset/security/manager.py +++ b/superset/security/manager.py @@ -39,6 +39,7 @@ from flask_appbuilder.security.views import ( ViewMenuModelView, ) from flask_appbuilder.widgets import ListWidget +from flask_login import AnonymousUserMixin from sqlalchemy import and_, or_ from sqlalchemy.engine.base import Connection from sqlalchemy.orm import Session @@ -1024,6 +1025,9 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods .one_or_none() ) + def get_anonymous_user(self) -> User: # pylint: disable=no-self-use + return AnonymousUserMixin() + def get_rls_filters(self, table: "BaseDatasource") -> List[SqlaQuery]: """ Retrieves the appropriate row level security filters for the current user and diff --git a/superset/tasks/async_queries.py b/superset/tasks/async_queries.py index 9fa0eb488..5fbe39e15 100644 --- a/superset/tasks/async_queries.py +++ b/superset/tasks/async_queries.py @@ -40,9 +40,11 @@ query_timeout = current_app.config[ def ensure_user_is_set(user_id: Optional[int]) -> None: - user_is_set = hasattr(g, "user") and g.user is not None - if not user_is_set and user_id is not None: + user_is_not_set = not (hasattr(g, "user") and g.user is not None) + if user_is_not_set and user_id is not None: g.user = security_manager.get_user_by_id(user_id) + elif user_is_not_set: + g.user = security_manager.get_anonymous_user() @celery_app.task(name="load_chart_data_into_cache", soft_time_limit=query_timeout) diff --git a/tests/tasks/async_queries_tests.py b/tests/tasks/async_queries_tests.py index cca58d8e2..6914854ca 100644 --- a/tests/tasks/async_queries_tests.py +++ b/tests/tasks/async_queries_tests.py @@ -15,12 +15,12 @@ # specific language governing permissions and limitations # under the License. """Unit tests for async query celery jobs in Superset""" -import re from unittest import mock from uuid import uuid4 import pytest from celery.exceptions import SoftTimeLimitExceeded +from flask import g from superset import db from superset.charts.commands.data import ChartDataCommand @@ -30,6 +30,7 @@ from superset.exceptions import SupersetException from superset.extensions import async_query_manager, security_manager from superset.tasks import async_queries from superset.tasks.async_queries import ( + ensure_user_is_set, load_chart_data_into_cache, load_explore_json_into_cache, ) @@ -202,3 +203,44 @@ class TestAsyncQueries(SupersetTestCase): ensure_user_is_set.side_effect = SoftTimeLimitExceeded() load_explore_json_into_cache(job_metadata, form_data) ensure_user_is_set.assert_called_once_with(user.id, "error", errors=errors) + + def test_ensure_user_is_set(self): + g_user_is_set = hasattr(g, "user") + original_g_user = g.user if g_user_is_set else None + + if g_user_is_set: + del g.user + + self.assertFalse(hasattr(g, "user")) + ensure_user_is_set(1) + self.assertTrue(hasattr(g, "user")) + self.assertFalse(g.user.is_anonymous) + self.assertEqual("1", g.user.get_id()) + + del g.user + + self.assertFalse(hasattr(g, "user")) + ensure_user_is_set(None) + self.assertTrue(hasattr(g, "user")) + self.assertTrue(g.user.is_anonymous) + self.assertEqual(None, g.user.get_id()) + + del g.user + + g.user = security_manager.get_user_by_id(2) + self.assertEqual("2", g.user.get_id()) + + ensure_user_is_set(1) + self.assertTrue(hasattr(g, "user")) + self.assertFalse(g.user.is_anonymous) + self.assertEqual("2", g.user.get_id()) + + ensure_user_is_set(None) + self.assertTrue(hasattr(g, "user")) + self.assertFalse(g.user.is_anonymous) + self.assertEqual("2", g.user.get_id()) + + if g_user_is_set: + g.user = original_g_user + else: + del g.user