diff --git a/superset/app.py b/superset/app.py index 5dc06840f..ad330f3f8 100644 --- a/superset/app.py +++ b/superset/app.py @@ -24,7 +24,6 @@ from flask import Flask, redirect from flask_appbuilder import expose, IndexView from flask_babel import gettext as __, lazy_gettext as _ from flask_compress import Compress -from flask_wtf import CSRFProtect from superset.connectors.connector_registry import ConnectorRegistry from superset.extensions import ( @@ -33,6 +32,7 @@ from superset.extensions import ( appbuilder, cache_manager, celery_app, + csrf, db, feature_flag_manager, jinja_context_manager, @@ -614,7 +614,7 @@ class SupersetAppInitializer: def configure_wtf(self) -> None: if self.config["WTF_CSRF_ENABLED"]: - csrf = CSRFProtect(self.flask_app) + csrf.init_app(self.flask_app) csrf_exempt_list = self.config["WTF_CSRF_EXEMPT_LIST"] for ex in csrf_exempt_list: csrf.exempt(ex) diff --git a/superset/config.py b/superset/config.py index d9ae516b8..d22fd80ab 100644 --- a/superset/config.py +++ b/superset/config.py @@ -171,7 +171,7 @@ QUERY_SEARCH_LIMIT = 1000 WTF_CSRF_ENABLED = True # Add endpoints that need to be exempt from CSRF protection -WTF_CSRF_EXEMPT_LIST = ["superset.views.core.log"] +WTF_CSRF_EXEMPT_LIST = ["superset.views.core.log", "superset.charts.api.data"] # Whether to run the web server in debug mode or not DEBUG = os.environ.get("FLASK_ENV") == "development" diff --git a/superset/extensions.py b/superset/extensions.py index a0dad8155..7cafef61a 100644 --- a/superset/extensions.py +++ b/superset/extensions.py @@ -29,6 +29,7 @@ from flask import Flask from flask_appbuilder import AppBuilder, SQLA from flask_migrate import Migrate from flask_talisman import Talisman +from flask_wtf.csrf import CSRFProtect from werkzeug.local import LocalProxy from superset.utils.cache_manager import CacheManager @@ -132,6 +133,7 @@ APP_DIR = os.path.dirname(__file__) appbuilder = AppBuilder(update_perms=False) cache_manager = CacheManager() celery_app = celery.Celery() +csrf = CSRFProtect() db = SQLA() _event_logger: Dict[str, Any] = {} event_logger = LocalProxy(lambda: _event_logger.get("event_logger"))