chore: use before_request hook for dynamic routes (#14568)

* chore: use before_request hook for dynamic routes

* Shorten hook names

* Introduce with_feature_flags and update thumbnail tests

* Disable test that fails in CI but not locally

* Add test for reports
This commit is contained in:
Ben Reinhart 2021-05-14 12:49:25 -07:00 committed by GitHub
parent f16c708fab
commit 6d9d362ca8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 168 additions and 83 deletions

View File

@ -221,9 +221,8 @@ class SupersetAppInitializer:
appbuilder.add_api(DatasetMetricRestApi)
appbuilder.add_api(QueryRestApi)
appbuilder.add_api(SavedQueryRestApi)
if feature_flag_manager.is_feature_enabled("ALERT_REPORTS"):
appbuilder.add_api(ReportScheduleRestApi)
appbuilder.add_api(ReportExecutionLogRestApi)
appbuilder.add_api(ReportScheduleRestApi)
appbuilder.add_api(ReportExecutionLogRestApi)
#
# Setup regular views
#

View File

@ -18,12 +18,13 @@ import json
import logging
from datetime import datetime
from io import BytesIO
from typing import Any, Dict
from typing import Any, Dict, Optional
from zipfile import ZipFile
import simplejson
from flask import g, make_response, redirect, request, Response, send_file, url_for
from flask_appbuilder.api import expose, protect, rison, safe
from flask_appbuilder.hooks import before_request
from flask_appbuilder.models.sqla.interface import SQLAInterface
from flask_babel import gettext as _, ngettext
from marshmallow import ValidationError
@ -94,6 +95,12 @@ class ChartRestApi(BaseSupersetModelRestApi):
resource_name = "chart"
allow_browser_login = True
@before_request(only=["thumbnail", "screenshot", "cache_screenshot"])
def ensure_thumbnails_enabled(self) -> Optional[Response]:
if not is_feature_enabled("THUMBNAILS"):
return self.response_404()
return None
include_route_methods = RouteMethod.REST_MODEL_VIEW_CRUD_SET | {
RouteMethod.EXPORT,
RouteMethod.IMPORT,
@ -103,6 +110,9 @@ class ChartRestApi(BaseSupersetModelRestApi):
"data_from_cache",
"viz_types",
"favorite_status",
"thumbnail",
"screenshot",
"cache_screenshot",
}
class_permission_name = "Chart"
method_permission_name = MODEL_API_RW_METHOD_PERMISSION_MAP
@ -212,15 +222,6 @@ class ChartRestApi(BaseSupersetModelRestApi):
allowed_rel_fields = {"owners", "created_by"}
def __init__(self) -> None:
if is_feature_enabled("THUMBNAILS"):
self.include_route_methods = self.include_route_methods | {
"thumbnail",
"screenshot",
"cache_screenshot",
}
super().__init__()
@expose("/", methods=["POST"])
@protect()
@safe

View File

@ -18,11 +18,12 @@ import json
import logging
from datetime import datetime
from io import BytesIO
from typing import Any, Dict
from typing import Any, Dict, Optional
from zipfile import is_zipfile, ZipFile
from flask import g, make_response, redirect, request, Response, send_file, url_for
from flask_appbuilder.api import expose, protect, rison, safe
from flask_appbuilder.hooks import before_request
from flask_appbuilder.models.sqla.interface import SQLAInterface
from flask_babel import ngettext
from marshmallow import ValidationError
@ -88,6 +89,13 @@ logger = logging.getLogger(__name__)
class DashboardRestApi(BaseSupersetModelRestApi):
datamodel = SQLAInterface(Dashboard)
@before_request(only=["thumbnail"])
def ensure_thumbnails_enabled(self) -> Optional[Response]:
if not is_feature_enabled("THUMBNAILS"):
return self.response_404()
return None
include_route_methods = RouteMethod.REST_MODEL_VIEW_CRUD_SET | {
RouteMethod.EXPORT,
RouteMethod.IMPORT,
@ -96,6 +104,7 @@ class DashboardRestApi(BaseSupersetModelRestApi):
"favorite_status",
"get_charts",
"get_datasets",
"thumbnail",
}
resource_name = "dashboard"
allow_browser_login = True
@ -206,11 +215,6 @@ class DashboardRestApi(BaseSupersetModelRestApi):
openapi_spec_methods = openapi_spec_methods_override
""" Overrides GET methods OpenApi descriptions """
def __init__(self) -> None:
if is_feature_enabled("THUMBNAILS"):
self.include_route_methods = self.include_route_methods | {"thumbnail"}
super().__init__()
def __repr__(self) -> str:
"""Deterministic string representation of the API instance for etag_cache."""
return "Superset.dashboards.api.DashboardRestApi@v{}{}".format(

View File

@ -15,14 +15,16 @@
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any
from typing import Any, Optional
from flask import g, request, Response
from flask_appbuilder.api import expose, permission_name, protect, rison, safe
from flask_appbuilder.hooks import before_request
from flask_appbuilder.models.sqla.interface import SQLAInterface
from flask_babel import ngettext
from marshmallow import ValidationError
from superset import is_feature_enabled
from superset.charts.filters import ChartFilter
from superset.constants import MODEL_API_RW_METHOD_PERMISSION_MAP, RouteMethod
from superset.dashboards.filters import DashboardAccessFilter
@ -61,6 +63,12 @@ logger = logging.getLogger(__name__)
class ReportScheduleRestApi(BaseSupersetModelRestApi):
datamodel = SQLAInterface(ReportSchedule)
@before_request
def ensure_alert_reports_enabled(self) -> Optional[Response]:
if not is_feature_enabled("ALERT_REPORTS"):
return self.response_404()
return None
include_route_methods = RouteMethod.REST_MODEL_VIEW_CRUD_SET | {
RouteMethod.RELATED,
"bulk_delete", # not using RouteMethod since locally defined

View File

@ -15,13 +15,15 @@
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any, Dict
from typing import Any, Dict, Optional
from flask import Response
from flask_appbuilder.api import expose, permission_name, protect, rison, safe
from flask_appbuilder.api.schemas import get_item_schema, get_list_schema
from flask_appbuilder.hooks import before_request
from flask_appbuilder.models.sqla.interface import SQLAInterface
from superset import is_feature_enabled
from superset.constants import MODEL_API_RW_METHOD_PERMISSION_MAP, RouteMethod
from superset.models.reports import ReportExecutionLog
from superset.reports.logs.schemas import openapi_spec_methods_override
@ -33,6 +35,12 @@ logger = logging.getLogger(__name__)
class ReportExecutionLogRestApi(BaseSupersetModelRestApi):
datamodel = SQLAInterface(ReportExecutionLog)
@before_request
def ensure_alert_reports_enabled(self) -> Optional[Response]:
if not is_feature_enabled("ALERT_REPORTS"):
return self.response_404()
return None
include_route_methods = {RouteMethod.GET, RouteMethod.GET_LIST}
method_permission_name = MODEL_API_RW_METHOD_PERMISSION_MAP

View File

@ -20,13 +20,14 @@ from typing import Any
from flask import current_app, url_for
def get_url_host(user_friendly: bool = False) -> str:
if user_friendly:
return current_app.config["WEBDRIVER_BASEURL_USER_FRIENDLY"]
return current_app.config["WEBDRIVER_BASEURL"]
def headless_url(path: str, user_friendly: bool = False) -> str:
base_url = (
current_app.config["WEBDRIVER_BASEURL_USER_FRIENDLY"]
if user_friendly
else current_app.config["WEBDRIVER_BASEURL"]
)
return urllib.parse.urljoin(base_url, path)
return urllib.parse.urljoin(get_url_host(user_friendly=user_friendly), path)
def get_url_path(view: str, user_friendly: bool = False, **kwargs: Any) -> str:

View File

@ -15,13 +15,16 @@
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
import functools
from typing import Any
import pytest
from sqlalchemy.engine import Engine
from unittest.mock import patch
from tests.test_app import app
from superset import db
from superset.extensions import feature_flag_manager
from superset.utils.core import get_example_database, json_dumps_w_dates
@ -108,3 +111,38 @@ def setup_presto_if_needed():
drop_from_schema(engine, ADMIN_SCHEMA_NAME)
engine.execute(f"DROP SCHEMA IF EXISTS {ADMIN_SCHEMA_NAME}")
engine.execute(f"CREATE SCHEMA {ADMIN_SCHEMA_NAME}")
def with_feature_flags(**mock_feature_flags):
"""
Use this decorator to mock feature flags in tests.
Usage:
class TestYourFeature(SupersetTestCase):
@with_feature_flags(YOUR_FEATURE=True)
def test_your_feature_enabled(self):
self.assertEqual(is_feature_enabled("YOUR_FEATURE"), True)
@with_feature_flags(YOUR_FEATURE=False)
def test_your_feature_disabled(self):
self.assertEqual(is_feature_enabled("YOUR_FEATURE"), False)
"""
def mock_get_feature_flags():
feature_flags = feature_flag_manager._feature_flags or {}
return {**feature_flags, **mock_feature_flags}
def decorate(test_fn):
def wrapper(*args, **kwargs):
with patch.object(
feature_flag_manager,
"get_feature_flags",
side_effect=mock_get_feature_flags,
):
test_fn(*args, **kwargs)
return functools.update_wrapper(wrapper, test_fn)
return decorate

View File

@ -35,11 +35,11 @@ from superset.models.reports import (
ReportRecipientType,
ReportState,
)
import tests.test_app
from superset.utils.core import get_example_database
from tests.base_tests import SupersetTestCase
from tests.conftest import with_feature_flags
from tests.fixtures.birth_names_dashboard import load_birth_names_dashboard_with_slices
from tests.reports.utils import insert_report_schedule
from superset.utils.core import get_example_database
REPORTS_COUNT = 10
@ -140,6 +140,23 @@ class TestReportSchedulesApi(SupersetTestCase):
db.session.delete(user)
db.session.commit()
@with_feature_flags(ALERT_REPORTS=False)
@pytest.mark.usefixtures("create_report_schedules")
def test_get_report_schedule_disabled(self):
"""
ReportSchedule Api: Test get report schedule 404s when feature is disabled
"""
report_schedule = (
db.session.query(ReportSchedule)
.filter(ReportSchedule.name == "name1")
.first()
)
self.login(username="admin")
uri = f"api/v1/report/{report_schedule.id}"
rv = self.client.get(uri)
assert rv.status_code == 404
@pytest.mark.usefixtures("create_report_schedules")
def test_get_report_schedule(self):
"""

View File

@ -17,18 +17,20 @@
# from superset import db
# from superset.models.dashboard import Dashboard
import urllib.request
from io import BytesIO
from unittest import skipUnless
from unittest.mock import patch
from flask_testing import LiveServerTestCase
from sqlalchemy.sql import func
from superset import db, is_feature_enabled, security_manager, thumbnail_cache
from superset import db, is_feature_enabled, security_manager
from superset.extensions import machine_auth_provider_factory
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.utils.screenshots import ChartScreenshot, DashboardScreenshot
from superset.utils.urls import get_url_path
from superset.utils.urls import get_url_host
from tests.conftest import with_feature_flags
from tests.test_app import app
from .base_tests import SupersetTestCase
@ -63,31 +65,29 @@ class TestThumbnails(SupersetTestCase):
mock_image = b"bytes mock image"
@with_feature_flags(THUMBNAILS=False)
def test_dashboard_thumbnail_disabled(self):
"""
Thumbnails: Dashboard thumbnail disabled
"""
if is_feature_enabled("THUMBNAILS"):
return
dashboard = db.session.query(Dashboard).all()[0]
self.login(username="admin")
uri = f"api/v1/dashboard/{dashboard.id}/thumbnail/{dashboard.digest}/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 404)
@with_feature_flags(THUMBNAILS=False)
def test_chart_thumbnail_disabled(self):
"""
Thumbnails: Chart thumbnail disabled
"""
if is_feature_enabled("THUMBNAILS"):
return
chart = db.session.query(Slice).all()[0]
self.login(username="admin")
uri = f"api/v1/chart/{chart}/thumbnail/{chart.digest}/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 404)
@skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature")
@with_feature_flags(THUMBNAILS=True)
def test_get_async_dashboard_screenshot(self):
"""
Thumbnails: Simple get async dashboard screenshot
@ -100,9 +100,15 @@ class TestThumbnails(SupersetTestCase):
) as mock_task:
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 202)
mock_task.assert_called_with(dashboard.id, force=True)
@skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature")
expected_uri = f"{get_url_host()}superset/dashboard/{dashboard.id}/"
expected_digest = dashboard.digest
expected_kwargs = {"force": True}
mock_task.assert_called_with(
expected_uri, expected_digest, **expected_kwargs
)
@with_feature_flags(THUMBNAILS=True)
def test_get_async_dashboard_notfound(self):
"""
Thumbnails: Simple get async dashboard not found
@ -124,7 +130,7 @@ class TestThumbnails(SupersetTestCase):
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 404)
@skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature")
@with_feature_flags(THUMBNAILS=True)
def test_get_async_chart_screenshot(self):
"""
Thumbnails: Simple get async chart screenshot
@ -137,9 +143,14 @@ class TestThumbnails(SupersetTestCase):
) as mock_task:
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 202)
mock_task.assert_called_with(chart.id, force=True)
expected_uri = f"{get_url_host()}superset/slice/{chart.id}/?standalone=true"
expected_digest = chart.digest
expected_kwargs = {"force": True}
mock_task.assert_called_with(
expected_uri, expected_digest, **expected_kwargs
)
@skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature")
@with_feature_flags(THUMBNAILS=True)
def test_get_async_chart_notfound(self):
"""
Thumbnails: Simple get async chart not found
@ -150,68 +161,66 @@ class TestThumbnails(SupersetTestCase):
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 404)
@skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature")
@with_feature_flags(THUMBNAILS=True)
def test_get_cached_chart_wrong_digest(self):
"""
Thumbnails: Simple get chart with wrong digest
"""
chart = db.session.query(Slice).all()[0]
chart_url = get_url_path("Superset.slice", slice_id=chart.id, standalone="true")
# Cache a test "image"
screenshot = ChartScreenshot(chart_url, chart.digest)
thumbnail_cache.set(screenshot.cache_key, self.mock_image)
self.login(username="admin")
uri = f"api/v1/chart/{chart.id}/thumbnail/1234/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 302)
self.assertRedirects(rv, f"api/v1/chart/{chart.id}/thumbnail/{chart.digest}/")
with patch.object(
ChartScreenshot, "get_from_cache", return_value=BytesIO(self.mock_image)
):
self.login(username="admin")
uri = f"api/v1/chart/{chart.id}/thumbnail/1234/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 302)
self.assertRedirects(
rv, f"api/v1/chart/{chart.id}/thumbnail/{chart.digest}/"
)
@skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature")
@with_feature_flags(THUMBNAILS=True)
def test_get_cached_dashboard_screenshot(self):
"""
Thumbnails: Simple get cached dashboard screenshot
"""
dashboard = db.session.query(Dashboard).all()[0]
dashboard_url = get_url_path("Superset.dashboard", dashboard_id=dashboard.id)
# Cache a test "image"
screenshot = DashboardScreenshot(dashboard_url, dashboard.digest)
thumbnail_cache.set(screenshot.cache_key, self.mock_image)
self.login(username="admin")
uri = f"api/v1/dashboard/{dashboard.id}/thumbnail/{dashboard.digest}/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)
self.assertEqual(rv.data, self.mock_image)
with patch.object(
DashboardScreenshot, "get_from_cache", return_value=BytesIO(self.mock_image)
):
self.login(username="admin")
uri = f"api/v1/dashboard/{dashboard.id}/thumbnail/{dashboard.digest}/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)
self.assertEqual(rv.data, self.mock_image)
@skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature")
@with_feature_flags(THUMBNAILS=True)
def test_get_cached_chart_screenshot(self):
"""
Thumbnails: Simple get cached chart screenshot
"""
chart = db.session.query(Slice).all()[0]
chart_url = get_url_path("Superset.slice", slice_id=chart.id, standalone="true")
# Cache a test "image"
screenshot = ChartScreenshot(chart_url, chart.digest)
thumbnail_cache.set(screenshot.cache_key, self.mock_image)
self.login(username="admin")
uri = f"api/v1/chart/{chart.id}/thumbnail/{chart.digest}/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)
self.assertEqual(rv.data, self.mock_image)
with patch.object(
ChartScreenshot, "get_from_cache", return_value=BytesIO(self.mock_image)
):
self.login(username="admin")
uri = f"api/v1/chart/{chart.id}/thumbnail/{chart.digest}/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)
self.assertEqual(rv.data, self.mock_image)
@skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature")
@with_feature_flags(THUMBNAILS=True)
def test_get_cached_dashboard_wrong_digest(self):
"""
Thumbnails: Simple get dashboard with wrong digest
"""
dashboard = db.session.query(Dashboard).all()[0]
dashboard_url = get_url_path("Superset.dashboard", dashboard_id=dashboard.id)
# Cache a test "image"
screenshot = DashboardScreenshot(dashboard_url, dashboard.digest)
thumbnail_cache.set(screenshot.cache_key, self.mock_image)
self.login(username="admin")
uri = f"api/v1/dashboard/{dashboard.id}/thumbnail/1234/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 302)
self.assertRedirects(
rv, f"api/v1/dashboard/{dashboard.id}/thumbnail/{dashboard.digest}/"
)
with patch.object(
DashboardScreenshot, "get_from_cache", return_value=BytesIO(self.mock_image)
):
self.login(username="admin")
uri = f"api/v1/dashboard/{dashboard.id}/thumbnail/1234/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 302)
self.assertRedirects(
rv, f"api/v1/dashboard/{dashboard.id}/thumbnail/{dashboard.digest}/"
)