fix(thumbnail cache): Enabling force parameter on screenshot/thumbnail cache (#31757)

Co-authored-by: Kamil Gabryjelski <kamil.gabryjelski@gmail.com>
This commit is contained in:
Jack 2025-01-31 12:22:31 -06:00 committed by GitHub
parent c590e90c87
commit 7db0589340
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 609 additions and 280 deletions

View File

@ -30,7 +30,7 @@ from marshmallow import ValidationError
from werkzeug.wrappers import Response as WerkzeugResponse
from werkzeug.wsgi import FileWrapper
from superset import app, is_feature_enabled, thumbnail_cache
from superset import app, is_feature_enabled
from superset.charts.filters import (
ChartAllTextFilter,
ChartCertifiedFilter,
@ -84,7 +84,12 @@ from superset.models.slice import Slice
from superset.tasks.thumbnails import cache_chart_thumbnail
from superset.tasks.utils import get_current_user
from superset.utils import json
from superset.utils.screenshots import ChartScreenshot, DEFAULT_CHART_WINDOW_SIZE
from superset.utils.screenshots import (
ChartScreenshot,
DEFAULT_CHART_WINDOW_SIZE,
ScreenshotCachePayload,
StatusValues,
)
from superset.utils.urls import get_url_path
from superset.views.base_api import (
BaseSupersetModelRestApi,
@ -564,8 +569,14 @@ class ChartRestApi(BaseSupersetModelRestApi):
schema:
$ref: '#/components/schemas/screenshot_query_schema'
responses:
200:
description: Chart async result
content:
application/json:
schema:
$ref: "#/components/schemas/ChartCacheScreenshotResponseSchema"
202:
description: Chart async result
description: Chart screenshot task created
content:
application/json:
schema:
@ -580,6 +591,7 @@ class ChartRestApi(BaseSupersetModelRestApi):
$ref: '#/components/responses/500'
"""
rison_dict = kwargs["rison"]
force = rison_dict.get("force")
window_size = rison_dict.get("window_size") or DEFAULT_CHART_WINDOW_SIZE
# Don't shrink the image if thumb_size is not specified
@ -591,25 +603,36 @@ class ChartRestApi(BaseSupersetModelRestApi):
chart_url = get_url_path("Superset.slice", slice_id=chart.id)
screenshot_obj = ChartScreenshot(chart_url, chart.digest)
cache_key = screenshot_obj.cache_key(window_size, thumb_size)
cache_key = screenshot_obj.get_cache_key(window_size, thumb_size)
cache_payload = (
screenshot_obj.get_from_cache_key(cache_key) or ScreenshotCachePayload()
)
image_url = get_url_path(
"ChartRestApi.screenshot", pk=chart.id, digest=cache_key
)
def trigger_celery() -> WerkzeugResponse:
def build_response(status_code: int) -> WerkzeugResponse:
return self.response(
status_code,
cache_key=cache_key,
chart_url=chart_url,
image_url=image_url,
task_updated_at=cache_payload.get_timestamp(),
task_status=cache_payload.get_status(),
)
if cache_payload.should_trigger_task(force):
logger.info("Triggering screenshot ASYNC")
screenshot_obj.cache.set(cache_key, ScreenshotCachePayload())
cache_chart_thumbnail.delay(
current_user=get_current_user(),
chart_id=chart.id,
force=True,
window_size=window_size,
thumb_size=thumb_size,
force=force,
)
return self.response(
202, cache_key=cache_key, chart_url=chart_url, image_url=image_url
)
return trigger_celery()
return build_response(202)
return build_response(200)
@expose("/<pk>/screenshot/<digest>/", methods=("GET",))
@protect()
@ -635,7 +658,7 @@ class ChartRestApi(BaseSupersetModelRestApi):
name: digest
responses:
200:
description: Chart thumbnail image
description: Chart screenshot image
content:
image/*:
schema:
@ -652,16 +675,16 @@ class ChartRestApi(BaseSupersetModelRestApi):
"""
chart = self.datamodel.get(pk, self._base_filters)
# Making sure the chart still exists
if not chart:
return self.response_404()
# fetch the chart screenshot using the current user and cache if set
if img := ChartScreenshot.get_from_cache_key(thumbnail_cache, digest):
return Response(
FileWrapper(img), mimetype="image/png", direct_passthrough=True
)
# TODO: return an empty image
if cache_payload := ChartScreenshot.get_from_cache_key(digest):
if cache_payload.status == StatusValues.UPDATED:
return Response(
FileWrapper(cache_payload.get_image()),
mimetype="image/png",
direct_passthrough=True,
)
return self.response_404()
@expose("/<pk>/thumbnail/<digest>/", methods=("GET",))
@ -685,9 +708,10 @@ class ChartRestApi(BaseSupersetModelRestApi):
type: integer
name: pk
- in: path
name: digest
description: A hex digest that makes this chart unique
schema:
type: string
name: digest
responses:
200:
description: Chart thumbnail image
@ -712,34 +736,6 @@ class ChartRestApi(BaseSupersetModelRestApi):
return self.response_404()
current_user = get_current_user()
url = get_url_path("Superset.slice", slice_id=chart.id)
if kwargs["rison"].get("force", False):
logger.info(
"Triggering thumbnail compute (chart id: %s) ASYNC", str(chart.id)
)
cache_chart_thumbnail.delay(
current_user=current_user,
chart_id=chart.id,
force=True,
)
return self.response(202, message="OK Async")
# fetch the chart screenshot using the current user and cache if set
screenshot = ChartScreenshot(url, chart.digest).get_from_cache(
cache=thumbnail_cache
)
# If not screenshot then send request to compute thumb to celery
if not screenshot:
self.incr_stats("async", self.thumbnail.__name__)
logger.info(
"Triggering thumbnail compute (chart id: %s) ASYNC", str(chart.id)
)
cache_chart_thumbnail.delay(
current_user=current_user,
chart_id=chart.id,
force=True,
)
return self.response(202, message="OK Async")
# If digests
if chart.digest != digest:
self.incr_stats("redirect", self.thumbnail.__name__)
return redirect(
@ -747,9 +743,34 @@ class ChartRestApi(BaseSupersetModelRestApi):
f"{self.__class__.__name__}.thumbnail", pk=pk, digest=chart.digest
)
)
url = get_url_path("Superset.slice", slice_id=chart.id)
screenshot_obj = ChartScreenshot(url, chart.digest)
cache_key = screenshot_obj.get_cache_key()
cache_payload = (
screenshot_obj.get_from_cache_key(cache_key) or ScreenshotCachePayload()
)
if cache_payload.should_trigger_task():
self.incr_stats("async", self.thumbnail.__name__)
logger.info(
"Triggering thumbnail compute (chart id: %s) ASYNC", str(chart.id)
)
screenshot_obj.cache.set(cache_key, ScreenshotCachePayload())
cache_chart_thumbnail.delay(
current_user=current_user,
chart_id=chart.id,
force=False,
)
return self.response(
202,
task_updated_at=cache_payload.get_timestamp(),
task_status=cache_payload.get_status(),
)
self.incr_stats("from_cache", self.thumbnail.__name__)
return Response(
FileWrapper(screenshot), mimetype="image/png", direct_passthrough=True
FileWrapper(cache_payload.get_image()),
mimetype="image/png",
direct_passthrough=True,
)
@expose("/export/", methods=("GET",))

View File

@ -304,6 +304,21 @@ class ChartCacheScreenshotResponseSchema(Schema):
image_url = fields.String(
metadata={"description": "The url to fetch the screenshot"}
)
task_status = fields.String(
metadata={"description": "The status of the async screenshot"}
)
task_updated_at = fields.String(
metadata={"description": "The timestamp of the last change in status"}
)
class ChartGetCachedScreenshotResponseSchema(Schema):
    # Response body describing only the state of an async screenshot task:
    # its current status and when that status last changed.

    # Status of the async screenshot task, serialized as a string.
    task_status = fields.String(
        metadata={"description": "The status of the async screenshot"},
    )

    # Timestamp of the most recent status change, serialized as a string.
    task_updated_at = fields.String(
        metadata={"description": "The timestamp of the last change in status"},
    )
class ChartDataColumnSchema(Schema):

View File

@ -729,8 +729,10 @@ THUMBNAIL_CHART_DIGEST_FUNC: Callable[[Slice, ExecutorType, str], str | None] |
THUMBNAIL_CACHE_CONFIG: CacheConfig = {
"CACHE_TYPE": "NullCache",
"CACHE_DEFAULT_TIMEOUT": int(timedelta(days=7).total_seconds()),
"CACHE_NO_NULL_WARNING": True,
}
THUMBNAIL_ERROR_CACHE_TTL = int(timedelta(days=1).total_seconds())
# Time before selenium times out after trying to locate an element on the page and wait
# for that element to load for a screenshot.

View File

@ -31,7 +31,7 @@ from marshmallow import ValidationError
from werkzeug.wrappers import Response as WerkzeugResponse
from werkzeug.wsgi import FileWrapper
from superset import db, thumbnail_cache
from superset import db
from superset.charts.schemas import ChartEntityResponseSchema
from superset.commands.dashboard.copy import CopyDashboardCommand
from superset.commands.dashboard.create import CreateDashboardCommand
@ -115,6 +115,7 @@ from superset.utils.pdf import build_pdf_from_screenshots
from superset.utils.screenshots import (
DashboardScreenshot,
DEFAULT_DASHBOARD_WINDOW_SIZE,
ScreenshotCachePayload,
)
from superset.utils.urls import get_url_path
from superset.views.base_api import (
@ -1022,110 +1023,6 @@ class DashboardRestApi(BaseSupersetModelRestApi):
response.set_cookie(token, "done", max_age=600)
return response
@expose("/<pk>/thumbnail/<digest>/", methods=("GET",))
@validate_feature_flags(["THUMBNAILS"])
@protect()
@safe
@rison(thumbnail_query_schema)
@event_logger.log_this_with_context(
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.thumbnail",
log_to_statsd=False,
)
def thumbnail(self, pk: int, digest: str, **kwargs: Any) -> WerkzeugResponse:
"""Compute async or get already computed dashboard thumbnail from cache.
---
get:
summary: Get dashboard's thumbnail
description: >-
Computes async or get already computed dashboard thumbnail from cache.
parameters:
- in: path
schema:
type: integer
name: pk
- in: path
name: digest
description: A hex digest that makes this dashboard unique
schema:
type: string
- in: query
name: q
content:
application/json:
schema:
$ref: '#/components/schemas/thumbnail_query_schema'
responses:
200:
description: Dashboard thumbnail image
content:
image/*:
schema:
type: string
format: binary
202:
description: Thumbnail does not exist on cache, fired async to compute
content:
application/json:
schema:
type: object
properties:
message:
type: string
302:
description: Redirects to the current digest
401:
$ref: '#/components/responses/401'
404:
$ref: '#/components/responses/404'
422:
$ref: '#/components/responses/422'
500:
$ref: '#/components/responses/500'
"""
dashboard = cast(Dashboard, self.datamodel.get(pk, self._base_filters))
if not dashboard:
return self.response_404()
dashboard_url = get_url_path(
"Superset.dashboard", dashboard_id_or_slug=dashboard.id
)
# If force, request a screenshot from the workers
current_user = get_current_user()
if kwargs["rison"].get("force", False):
cache_dashboard_thumbnail.delay(
current_user=current_user,
dashboard_id=dashboard.id,
force=True,
)
return self.response(202, message="OK Async")
# fetch the dashboard screenshot using the current user and cache if set
screenshot = DashboardScreenshot(
dashboard_url, dashboard.digest
).get_from_cache(cache=thumbnail_cache)
# If the screenshot does not exist, request one from the workers
if not screenshot:
self.incr_stats("async", self.thumbnail.__name__)
cache_dashboard_thumbnail.delay(
current_user=current_user,
dashboard_id=dashboard.id,
force=True,
)
return self.response(202, message="OK Async")
# If digests
if dashboard.digest != digest:
self.incr_stats("redirect", self.thumbnail.__name__)
return redirect(
url_for(
f"{self.__class__.__name__}.thumbnail",
pk=pk,
digest=dashboard.digest,
)
)
self.incr_stats("from_cache", self.thumbnail.__name__)
return Response(
FileWrapper(screenshot), mimetype="image/png", direct_passthrough=True
)
@expose("/<pk>/cache_dashboard_screenshot/", methods=("POST",))
@validate_feature_flags(["THUMBNAILS", "ENABLE_DASHBOARD_SCREENSHOT_ENDPOINTS"])
@protect()
@ -1172,7 +1069,6 @@ class DashboardRestApi(BaseSupersetModelRestApi):
payload = CacheScreenshotSchema().load(request.json)
except ValidationError as error:
return self.response_400(message=error.messages)
dashboard = cast(Dashboard, self.datamodel.get(pk, self._base_filters))
if not dashboard:
return self.response_404()
@ -1182,7 +1078,7 @@ class DashboardRestApi(BaseSupersetModelRestApi):
)
# Don't shrink the image if thumb_size is not specified
thumb_size = kwargs["rison"].get("thumb_size") or window_size
force = kwargs["rison"].get("force", False)
dashboard_state: DashboardPermalinkState = {
"dataMask": payload.get("dataMask", {}),
"activeTabs": payload.get("activeTabs", []),
@ -1197,13 +1093,29 @@ class DashboardRestApi(BaseSupersetModelRestApi):
dashboard_url = get_url_path("Superset.dashboard_permalink", key=permalink_key)
screenshot_obj = DashboardScreenshot(dashboard_url, dashboard.digest)
cache_key = screenshot_obj.cache_key(window_size, thumb_size, dashboard_state)
cache_key = screenshot_obj.get_cache_key(
window_size, thumb_size, dashboard_state
)
image_url = get_url_path(
"DashboardRestApi.screenshot", pk=dashboard.id, digest=cache_key
)
cache_payload = (
screenshot_obj.get_from_cache_key(cache_key) or ScreenshotCachePayload()
)
def trigger_celery() -> WerkzeugResponse:
def build_response(status_code: int) -> WerkzeugResponse:
return self.response(
status_code,
cache_key=cache_key,
dashboard_url=dashboard_url,
image_url=image_url,
task_updated_at=cache_payload.get_timestamp(),
task_status=cache_payload.get_status(),
)
if cache_payload.should_trigger_task(force):
logger.info("Triggering screenshot ASYNC")
screenshot_obj.cache.set(cache_key, ScreenshotCachePayload())
cache_dashboard_screenshot.delay(
username=get_current_user(),
guest_token=(
@ -1213,19 +1125,12 @@ class DashboardRestApi(BaseSupersetModelRestApi):
),
dashboard_id=dashboard.id,
dashboard_url=dashboard_url,
cache_key=cache_key,
force=False,
thumb_size=thumb_size,
window_size=window_size,
force=force,
)
return self.response(
202,
cache_key=cache_key,
dashboard_url=dashboard_url,
image_url=image_url,
)
return trigger_celery()
return build_response(202)
return build_response(200)
@expose("/<pk>/screenshot/<digest>/", methods=("GET",))
@validate_feature_flags(["THUMBNAILS", "ENABLE_DASHBOARD_SCREENSHOT_ENDPOINTS"])
@ -1282,9 +1187,12 @@ class DashboardRestApi(BaseSupersetModelRestApi):
# fetch the dashboard screenshot using the current user and cache if set
if img := DashboardScreenshot.get_from_cache_key(thumbnail_cache, digest):
if cache_payload := DashboardScreenshot.get_from_cache_key(digest):
image = cache_payload.get_image()
if not image:
return self.response_404()
if download_format == "pdf":
pdf_img = img.getvalue()
pdf_img = image.getvalue()
# Convert the screenshot to PDF
pdf_data = build_pdf_from_screenshots([pdf_img])
@ -1296,13 +1204,120 @@ class DashboardRestApi(BaseSupersetModelRestApi):
)
if download_format == "png":
return Response(
FileWrapper(img),
FileWrapper(image),
mimetype="image/png",
direct_passthrough=True,
)
return self.response_404()
@expose("/<pk>/thumbnail/<digest>/", methods=("GET",))
@validate_feature_flags(["THUMBNAILS"])
@protect()
@safe
@rison(thumbnail_query_schema)
@event_logger.log_this_with_context(
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.thumbnail",
log_to_statsd=False,
)
def thumbnail(self, pk: int, digest: str, **kwargs: Any) -> WerkzeugResponse:
"""Compute async or get already computed dashboard thumbnail from cache.
---
get:
summary: Get dashboard's thumbnail
description: >-
Computes async or get already computed dashboard thumbnail from cache.
parameters:
- in: path
schema:
type: integer
name: pk
- in: path
name: digest
description: A hex digest that makes this dashboard unique
schema:
type: string
responses:
200:
description: Dashboard thumbnail image
content:
image/*:
schema:
type: string
format: binary
202:
description: Thumbnail does not exist on cache, fired async to compute
content:
application/json:
schema:
type: object
properties:
message:
type: string
302:
description: Redirects to the current digest
401:
$ref: '#/components/responses/401'
404:
$ref: '#/components/responses/404'
422:
$ref: '#/components/responses/422'
500:
$ref: '#/components/responses/500'
"""
dashboard = cast(Dashboard, self.datamodel.get(pk, self._base_filters))
if not dashboard:
return self.response_404()
current_user = get_current_user()
dashboard_url = get_url_path(
"Superset.dashboard", dashboard_id_or_slug=dashboard.id
)
if dashboard.digest != digest:
self.incr_stats("redirect", self.thumbnail.__name__)
return redirect(
url_for(
f"{self.__class__.__name__}.thumbnail",
pk=pk,
digest=dashboard.digest,
)
)
screenshot_obj = DashboardScreenshot(dashboard_url, dashboard.digest)
cache_key = screenshot_obj.get_cache_key()
cache_payload = (
screenshot_obj.get_from_cache_key(cache_key) or ScreenshotCachePayload()
)
image_url = get_url_path(
"DashboardRestApi.thumbnail", pk=dashboard.id, digest=cache_key
)
if cache_payload.should_trigger_task():
self.incr_stats("async", self.thumbnail.__name__)
logger.info(
"Triggering thumbnail compute (dashboard id: %s) ASYNC",
str(dashboard.id),
)
screenshot_obj.cache.set(cache_key, ScreenshotCachePayload())
cache_dashboard_thumbnail.delay(
current_user=current_user,
dashboard_id=dashboard.id,
force=False,
)
return self.response(
202,
cache_key=cache_key,
dashboard_url=dashboard_url,
image_url=image_url,
task_updated_at=cache_payload.get_timestamp(),
task_status=cache_payload.get_status(),
)
self.incr_stats("from_cache", self.thumbnail.__name__)
return Response(
FileWrapper(cache_payload.get_image()),
mimetype="image/png",
direct_passthrough=True,
)
@expose("/favorite_status/", methods=("GET",))
@protect()
@safe

View File

@ -507,6 +507,12 @@ class DashboardCacheScreenshotResponseSchema(Schema):
image_url = fields.String(
metadata={"description": "The url to fetch the screenshot"}
)
task_status = fields.String(
metadata={"description": "The status of the async screenshot"}
)
task_updated_at = fields.String(
metadata={"description": "The timestamp of the last change in status"}
)
class CacheScreenshotSchema(Schema):

View File

@ -380,9 +380,7 @@ def event_after_chart_changed(
_mapper: Mapper, _connection: Connection, target: Slice
) -> None:
cache_chart_thumbnail.delay(
current_user=get_current_user(),
chart_id=target.id,
force=True,
current_user=get_current_user(), chart_id=target.id, force=True
)

View File

@ -38,7 +38,7 @@ logger = logging.getLogger(__name__)
def cache_chart_thumbnail(
current_user: Optional[str],
chart_id: int,
force: bool = False,
force: bool,
window_size: Optional[WindowSize] = None,
thumb_size: Optional[WindowSize] = None,
) -> None:
@ -64,10 +64,9 @@ def cache_chart_thumbnail(
screenshot = ChartScreenshot(url, chart.digest)
screenshot.compute_and_cache(
user=user,
cache=thumbnail_cache,
force=force,
window_size=window_size,
thumb_size=thumb_size,
force=force,
)
return None
@ -76,7 +75,7 @@ def cache_chart_thumbnail(
def cache_dashboard_thumbnail(
current_user: Optional[str],
dashboard_id: int,
force: bool = False,
force: bool,
thumb_size: Optional[WindowSize] = None,
window_size: Optional[WindowSize] = None,
) -> None:
@ -101,10 +100,9 @@ def cache_dashboard_thumbnail(
screenshot = DashboardScreenshot(url, dashboard.digest)
screenshot.compute_and_cache(
user=user,
cache=thumbnail_cache,
force=force,
window_size=window_size,
thumb_size=thumb_size,
force=force,
)
@ -113,7 +111,7 @@ def cache_dashboard_screenshot( # pylint: disable=too-many-arguments
username: str,
dashboard_id: int,
dashboard_url: str,
force: bool = True,
force: bool,
cache_key: Optional[str] = None,
guest_token: Optional[GuestToken] = None,
thumb_size: Optional[WindowSize] = None,
@ -145,9 +143,8 @@ def cache_dashboard_screenshot( # pylint: disable=too-many-arguments
screenshot = DashboardScreenshot(dashboard_url, dashboard.digest)
screenshot.compute_and_cache(
user=current_user,
cache=thumbnail_cache,
force=force,
window_size=window_size,
thumb_size=thumb_size,
cache_key=cache_key,
force=force,
)

View File

@ -17,12 +17,14 @@
from __future__ import annotations
import logging
from datetime import datetime
from enum import Enum
from io import BytesIO
from typing import TYPE_CHECKING
from flask import current_app
from superset import feature_flag_manager
from superset import app, feature_flag_manager, thumbnail_cache
from superset.dashboards.permalink.types import DashboardPermalinkState
from superset.extensions import event_logger
from superset.utils.hashing import md5_sha_from_dict
@ -54,6 +56,70 @@ if TYPE_CHECKING:
from flask_caching import Cache
class StatusValues(Enum):
    """Lifecycle states recorded on a cached screenshot payload.

    The member values are the human-readable strings reported to API
    clients as ``task_status``.
    """

    # No image yet; this is the initial state of an empty payload.
    PENDING = "Pending"
    # Set right before the screenshot is generated.
    COMPUTING = "Computing"
    # An image has been stored on the payload.
    UPDATED = "Updated"
    # Generating or resizing the screenshot raised an exception.
    ERROR = "Error"
class ScreenshotCachePayload:
    """Entry stored in the screenshot cache.

    Holds the optional raw image bytes, an ISO-8601 timestamp (naive local
    time from ``datetime.now()``) of the last status change, and the current
    ``StatusValues`` state of the screenshot task.
    """

    def __init__(self, image: bytes | None = None):
        """Create a payload; it starts as UPDATED when ``image`` is non-empty,
        otherwise as PENDING."""
        self._image = image
        self._timestamp = datetime.now().isoformat()
        self.status = StatusValues.UPDATED if image else StatusValues.PENDING

    def update_timestamp(self) -> None:
        """Stamp the payload with the current time in ISO format."""
        self._timestamp = datetime.now().isoformat()

    def pending(self) -> None:
        """Drop any stored image and mark the payload as PENDING."""
        self.update_timestamp()
        self._image = None
        self.status = StatusValues.PENDING

    def computing(self) -> None:
        """Drop any stored image and mark the payload as COMPUTING."""
        self.update_timestamp()
        self._image = None
        self.status = StatusValues.COMPUTING

    def update(self, image: bytes) -> None:
        """Store ``image`` and mark the payload as UPDATED."""
        self.update_timestamp()
        self.status = StatusValues.UPDATED
        self._image = image

    def error(self) -> None:
        """Mark the payload as ERROR; the stored image is left untouched."""
        self.update_timestamp()
        self.status = StatusValues.ERROR

    def get_image(self) -> BytesIO | None:
        """Return the image wrapped in a fresh ``BytesIO``, or ``None`` when
        no (or empty) image bytes are stored."""
        return BytesIO(self._image) if self._image else None

    def get_timestamp(self) -> str:
        """Return the ISO-format timestamp of the last status change."""
        return self._timestamp

    def get_status(self) -> str:
        """Return the string value of the current status."""
        return self.status.value

    def is_error_cache_ttl_expired(self) -> bool:
        """Tell whether more than ``THUMBNAIL_ERROR_CACHE_TTL`` seconds have
        passed since the payload's timestamp."""
        ttl_seconds = app.config["THUMBNAIL_ERROR_CACHE_TTL"]
        elapsed = datetime.now() - datetime.fromisoformat(self.get_timestamp())
        return elapsed.total_seconds() > ttl_seconds

    def should_trigger_task(self, force: bool = False) -> bool:
        """Decide whether a new screenshot task should be queued.

        A task is due when ``force`` is truthy, when the payload is still
        PENDING, or when it is in ERROR and the error TTL has expired.
        """
        if force:
            return force
        if self.status == StatusValues.PENDING:
            return True
        if self.status == StatusValues.ERROR:
            # Failed attempts are retried only once the error TTL has elapsed.
            return self.is_error_cache_ttl_expired()
        return False
class BaseScreenshot:
driver_type = current_app.config["WEBDRIVER_TYPE"]
url: str
@ -63,6 +129,7 @@ class BaseScreenshot:
element: str = ""
window_size: WindowSize = DEFAULT_SCREENSHOT_WINDOW_SIZE
thumb_size: WindowSize = DEFAULT_SCREENSHOT_THUMBNAIL_SIZE
cache: Cache = thumbnail_cache
def __init__(self, url: str, digest: str | None):
self.digest = digest
@ -75,7 +142,14 @@ class BaseScreenshot:
return WebDriverPlaywright(self.driver_type, window_size)
return WebDriverSelenium(self.driver_type, window_size)
def cache_key(
def get_screenshot(
    self, user: User, window_size: WindowSize | None = None
) -> bytes | None:
    """Take a screenshot of ``self.url`` / ``self.element`` on behalf of ``user``.

    The result is also kept on ``self.screenshot``.

    :param user: User passed through to the web driver
    :param window_size: Window size handed to ``self.driver``
    :return: Whatever the web driver's ``get_screenshot`` returns
    """
    web_driver = self.driver(window_size)
    captured = web_driver.get_screenshot(self.url, self.element, user)
    self.screenshot = captured
    return self.screenshot
def get_cache_key(
self,
window_size: bool | WindowSize | None = None,
thumb_size: bool | WindowSize | None = None,
@ -91,69 +165,35 @@ class BaseScreenshot:
}
return md5_sha_from_dict(args)
def get_screenshot(
self, user: User, window_size: WindowSize | None = None
) -> bytes | None:
driver = self.driver(window_size)
with event_logger.log_context("screenshot", screenshot_url=self.url):
self.screenshot = driver.get_screenshot(self.url, self.element, user)
return self.screenshot
def get(
self,
user: User = None,
cache: Cache = None,
thumb_size: WindowSize | None = None,
) -> BytesIO | None:
"""
Get thumbnail screenshot has BytesIO from cache or fetch
:param user: None to use current user or User Model to login and fetch
:param cache: The cache to use
:param thumb_size: Override thumbnail site
"""
payload: bytes | None = None
cache_key = self.cache_key(self.window_size, thumb_size)
if cache:
payload = cache.get(cache_key)
if not payload:
payload = self.compute_and_cache(
user=user, thumb_size=thumb_size, cache=cache
)
else:
logger.info("Loaded thumbnail from cache: %s", cache_key)
if payload:
return BytesIO(payload)
return None
def get_from_cache(
self,
cache: Cache,
window_size: WindowSize | None = None,
thumb_size: WindowSize | None = None,
) -> BytesIO | None:
cache_key = self.cache_key(window_size, thumb_size)
return self.get_from_cache_key(cache, cache_key)
) -> ScreenshotCachePayload | None:
cache_key = self.get_cache_key(window_size, thumb_size)
return self.get_from_cache_key(cache_key)
@staticmethod
def get_from_cache_key(cache: Cache, cache_key: str) -> BytesIO | None:
@classmethod
def get_from_cache_key(cls, cache_key: str) -> ScreenshotCachePayload | None:
logger.info("Attempting to get from cache: %s", cache_key)
if payload := cache.get(cache_key):
return BytesIO(payload)
if payload := cls.cache.get(cache_key):
# for backwards compatibility, byte objects should be converted
if not isinstance(payload, ScreenshotCachePayload):
payload = ScreenshotCachePayload(payload)
return payload
logger.info("Failed at getting from cache: %s", cache_key)
return None
def compute_and_cache( # pylint: disable=too-many-arguments
self,
force: bool,
user: User = None,
window_size: WindowSize | None = None,
thumb_size: WindowSize | None = None,
cache: Cache = None,
force: bool = True,
cache_key: str | None = None,
) -> bytes | None:
) -> None:
"""
Fetches the screenshot, computes the thumbnail and caches the result
Computes the thumbnail and caches the result
:param user: If no user is given will use the current context
:param cache: The cache to keep the thumbnail payload
@ -162,40 +202,46 @@ class BaseScreenshot:
:param force: Will force the computation even if it's already cached
:return: Image payload
"""
cache_key = cache_key or self.cache_key(window_size, thumb_size)
cache_key = cache_key or self.get_cache_key(window_size, thumb_size)
cache_payload = self.get_from_cache_key(cache_key) or ScreenshotCachePayload()
if (
cache_payload.status in [StatusValues.COMPUTING, StatusValues.UPDATED]
and not force
):
logger.info(
"Skipping compute - already processed for thumbnail: %s", cache_key
)
return
window_size = window_size or self.window_size
thumb_size = thumb_size or self.thumb_size
if not force and cache and cache.get(cache_key):
logger.info("Thumb already cached, skipping...")
return None
logger.info("Processing url for thumbnail: %s", cache_key)
payload = None
cache_payload.computing()
self.cache.set(cache_key, cache_payload)
image = None
# Assuming all sorts of things can go wrong with Selenium
try:
with event_logger.log_context(
f"screenshot.compute.{self.thumbnail_type}", force=force
):
payload = self.get_screenshot(user=user, window_size=window_size)
logger.info("trying to generate screenshot")
with event_logger.log_context(f"screenshot.compute.{self.thumbnail_type}"):
image = self.get_screenshot(user=user, window_size=window_size)
except Exception as ex: # pylint: disable=broad-except
logger.warning("Failed at generating thumbnail %s", ex, exc_info=True)
if payload and window_size != thumb_size:
cache_payload.error()
if image and window_size != thumb_size:
try:
payload = self.resize_image(payload, thumb_size=thumb_size)
image = self.resize_image(image, thumb_size=thumb_size)
except Exception as ex: # pylint: disable=broad-except
logger.warning("Failed at resizing thumbnail %s", ex, exc_info=True)
payload = None
cache_payload.error()
image = None
if payload:
if image:
logger.info("Caching thumbnail: %s", cache_key)
with event_logger.log_context(
f"screenshot.cache.{self.thumbnail_type}", force=force
):
cache.set(cache_key, payload)
logger.info("Done caching thumbnail")
return payload
with event_logger.log_context(f"screenshot.cache.{self.thumbnail_type}"):
cache_payload.update(image)
self.cache.set(cache_key, cache_payload)
logger.info("Updated thumbnail cache; Status: %s", cache_payload.get_status())
return
@classmethod
def resize_image(
@ -265,7 +311,7 @@ class DashboardScreenshot(BaseScreenshot):
self.window_size = window_size or DEFAULT_DASHBOARD_WINDOW_SIZE
self.thumb_size = thumb_size or DEFAULT_DASHBOARD_THUMBNAIL_SIZE
def cache_key(
def get_cache_key(
self,
window_size: bool | WindowSize | None = None,
thumb_size: bool | WindowSize | None = None,

View File

@ -380,7 +380,7 @@ class WebDriverSelenium(WebDriverProxy):
return error_messages
def get_screenshot(self, url: str, element_name: str, user: User) -> bytes | None:
def get_screenshot(self, url: str, element_name: str, user: User) -> bytes | None: # noqa: C901
driver = self.auth(user)
driver.set_window_size(*self._window)
driver.get(url)
@ -411,6 +411,7 @@ class WebDriverSelenium(WebDriverProxy):
)
)
except TimeoutException:
logger.info("Timeout Exception caught")
# Fallback to allow a screenshot of an empty dashboard
try:
WebDriverWait(driver, 0).until(
@ -461,18 +462,23 @@ class WebDriverSelenium(WebDriverProxy):
)
img = element.screenshot_as_png
except Exception as ex:
logger.warning("exception in webdriver", exc_info=ex)
raise
except TimeoutException:
# raise again for the finally block, but handled above
pass
raise
except StaleElementReferenceException:
logger.exception(
"Selenium got a stale element while requesting url %s",
url,
)
raise
except WebDriverException:
logger.exception(
"Encountered an unexpected error when requesting url %s", url
)
raise
finally:
self.destroy(driver, current_app.config["SCREENSHOT_SELENIUM_RETRIES"])
return img

View File

@ -319,9 +319,5 @@ def test_compute_thumbnails(thumbnail_mock, app_context, fs):
["-d", "-i", dashboard.id],
)
thumbnail_mock.assert_called_with(
None,
dashboard.id,
force=False,
)
thumbnail_mock.assert_called_with(None, dashboard.id, force=False)
assert response.exit_code == 0

View File

@ -37,6 +37,7 @@ from superset.reports.models import ReportSchedule, ReportScheduleType
from superset.models.slice import Slice
from superset.tags.models import Tag, TaggedObject, TagType, ObjectType
from superset.utils.core import backend, override_user
from superset.utils.screenshots import ScreenshotCachePayload
from superset.utils import json
from tests.integration_tests.base_api_tests import ApiOwnersTestCaseMixin
@ -3069,13 +3070,15 @@ class TestDashboardApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCas
@pytest.mark.usefixtures("create_dashboard_with_tag")
@patch("superset.dashboards.api.cache_dashboard_screenshot")
@patch("superset.dashboards.api.DashboardScreenshot.get_from_cache_key")
def test_screenshot_success_png(self, mock_get_cache, mock_cache_task):
def test_screenshot_success_png(self, mock_get_from_cache_key, mock_cache_task):
"""
Validate screenshot returns png
"""
self.login(ADMIN_USERNAME)
mock_cache_task.return_value = None
mock_get_cache.return_value = BytesIO(b"fake image data")
mock_get_from_cache_key.return_value = ScreenshotCachePayload(
b"fake image data"
)
dashboard = (
db.session.query(Dashboard)
@ -3083,7 +3086,7 @@ class TestDashboardApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCas
.first()
)
cache_resp = self._cache_screenshot(dashboard.id)
assert cache_resp.status_code == 202
assert cache_resp.status_code == 200
cache_key = json.loads(cache_resp.data.decode("utf-8"))["cache_key"]
response = self._get_screenshot(dashboard.id, cache_key, "png")
@ -3091,20 +3094,29 @@ class TestDashboardApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCas
assert response.mimetype == "image/png"
assert response.data == b"fake image data"
mock_get_from_cache_key.return_value = ScreenshotCachePayload()
cache_resp = self._cache_screenshot(dashboard.id)
assert cache_resp.status_code == 202
@with_feature_flags(THUMBNAILS=True, ENABLE_DASHBOARD_SCREENSHOT_ENDPOINTS=True)
@pytest.mark.usefixtures("create_dashboard_with_tag")
@patch("superset.dashboards.api.cache_dashboard_screenshot")
@patch("superset.dashboards.api.build_pdf_from_screenshots")
@patch("superset.dashboards.api.DashboardScreenshot.get_from_cache_key")
def test_screenshot_success_pdf(
self, mock_get_from_cache, mock_build_pdf, mock_cache_task
self,
mock_get_from_cache_key,
mock_build_pdf,
mock_cache_task,
):
"""
Validate screenshot can return pdf.
"""
self.login(ADMIN_USERNAME)
mock_cache_task.return_value = None
mock_get_from_cache.return_value = BytesIO(b"fake image data")
mock_get_from_cache_key.return_value = ScreenshotCachePayload(
b"fake image data"
)
mock_build_pdf.return_value = b"fake pdf data"
dashboard = (
@ -3113,7 +3125,7 @@ class TestDashboardApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCas
.first()
)
cache_resp = self._cache_screenshot(dashboard.id)
assert cache_resp.status_code == 202
assert cache_resp.status_code == 200
cache_key = json.loads(cache_resp.data.decode("utf-8"))["cache_key"]
response = self._get_screenshot(dashboard.id, cache_key, "pdf")
@ -3121,6 +3133,10 @@ class TestDashboardApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCas
assert response.mimetype == "application/pdf"
assert response.data == b"fake pdf data"
mock_get_from_cache_key.return_value = ScreenshotCachePayload()
cache_resp = self._cache_screenshot(dashboard.id)
assert cache_resp.status_code == 202
@with_feature_flags(THUMBNAILS=True, ENABLE_DASHBOARD_SCREENSHOT_ENDPOINTS=True)
@pytest.mark.usefixtures("create_dashboard_with_tag")
@patch("superset.dashboards.api.cache_dashboard_screenshot")
@ -3153,10 +3169,12 @@ class TestDashboardApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCas
@pytest.mark.usefixtures("create_dashboard_with_tag")
@patch("superset.dashboards.api.cache_dashboard_screenshot")
@patch("superset.dashboards.api.DashboardScreenshot.get_from_cache_key")
def test_screenshot_invalid_download_format(self, mock_get_cache, mock_cache_task):
def test_screenshot_invalid_download_format(
self, mock_get_from_cache_key, mock_cache_task
):
self.login(ADMIN_USERNAME)
mock_cache_task.return_value = None
mock_get_cache.return_value = BytesIO(b"fake png data")
mock_get_from_cache_key.return_value = ScreenshotCachePayload(b"fake png data")
dashboard = (
db.session.query(Dashboard)
@ -3165,9 +3183,13 @@ class TestDashboardApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCas
)
cache_resp = self._cache_screenshot(dashboard.id)
assert cache_resp.status_code == 202
assert cache_resp.status_code == 200
cache_key = json.loads(cache_resp.data.decode("utf-8"))["cache_key"]
mock_get_from_cache_key.return_value = ScreenshotCachePayload()
cache_resp = self._cache_screenshot(dashboard.id)
assert cache_resp.status_code == 202
response = self._get_screenshot(dashboard.id, cache_key, "invalid")
assert response.status_code == 404

View File

@ -18,7 +18,6 @@
# from superset.models.dashboard import Dashboard
import urllib.request
from io import BytesIO
from unittest import skipUnless
from unittest.mock import ANY, call, MagicMock, patch
@ -32,7 +31,11 @@ from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.tasks.types import ExecutorType, FixedExecutor
from superset.utils import json
from superset.utils.screenshots import ChartScreenshot, DashboardScreenshot
from superset.utils.screenshots import (
ChartScreenshot,
DashboardScreenshot,
ScreenshotCachePayload,
)
from superset.utils.urls import get_url_path
from superset.utils.webdriver import WebDriverSelenium
from tests.integration_tests.base_tests import SupersetTestCase
@ -287,14 +290,14 @@ class TestThumbnails(SupersetTestCase):
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature")
def test_get_async_dashboard_not_allowed(self):
def test_get_async_dashboard_created(self):
"""
Thumbnails: Simple get async dashboard is accepted with a 202 response
"""
self.login(ADMIN_USERNAME)
_, thumbnail_url = self._get_id_and_thumbnail_url(DASHBOARD_URL)
rv = self.client.get(thumbnail_url)
assert rv.status_code == 404
assert rv.status_code == 202
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@with_feature_flags(THUMBNAILS=True)
@ -370,7 +373,9 @@ class TestThumbnails(SupersetTestCase):
Thumbnails: Simple get chart with wrong digest
"""
with patch.object(
ChartScreenshot, "get_from_cache", return_value=BytesIO(self.mock_image)
ChartScreenshot,
"get_from_cache",
return_value=ScreenshotCachePayload(self.mock_image),
):
self.login(ADMIN_USERNAME)
id_, thumbnail_url = self._get_id_and_thumbnail_url(CHART_URL)
@ -385,7 +390,9 @@ class TestThumbnails(SupersetTestCase):
Thumbnails: Simple get cached dashboard screenshot
"""
with patch.object(
DashboardScreenshot, "get_from_cache", return_value=BytesIO(self.mock_image)
DashboardScreenshot,
"get_from_cache_key",
return_value=ScreenshotCachePayload(self.mock_image),
):
self.login(ADMIN_USERNAME)
_, thumbnail_url = self._get_id_and_thumbnail_url(DASHBOARD_URL)
@ -400,7 +407,9 @@ class TestThumbnails(SupersetTestCase):
Thumbnails: Simple get cached chart screenshot
"""
with patch.object(
ChartScreenshot, "get_from_cache", return_value=BytesIO(self.mock_image)
ChartScreenshot,
"get_from_cache_key",
return_value=ScreenshotCachePayload(self.mock_image),
):
self.login(ADMIN_USERNAME)
id_, thumbnail_url = self._get_id_and_thumbnail_url(CHART_URL)
@ -415,7 +424,9 @@ class TestThumbnails(SupersetTestCase):
Thumbnails: Simple get dashboard with wrong digest
"""
with patch.object(
DashboardScreenshot, "get_from_cache", return_value=BytesIO(self.mock_image)
DashboardScreenshot,
"get_from_cache",
return_value=ScreenshotCachePayload(self.mock_image),
):
self.login(ADMIN_USERNAME)
id_, thumbnail_url = self._get_id_and_thumbnail_url(DASHBOARD_URL)

View File

@ -0,0 +1,194 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=import-outside-toplevel, unused-argument
from unittest.mock import MagicMock
import pytest
from pytest_mock import MockerFixture
from superset.utils.hashing import md5_sha_from_dict
from superset.utils.screenshots import (
BaseScreenshot,
ScreenshotCachePayload,
StatusValues,
)
# Dotted import path of BaseScreenshot; tests append ".<attribute>" to build
# the targets passed to mocker.patch.
BASE_SCREENSHOT_PATH = "superset.utils.screenshots.BaseScreenshot"
class MockCache:
    """Single-slot stand-in for the screenshot cache used by these tests.

    Keys are accepted but ignored: ``set`` overwrites the one stored value and
    ``get`` returns it regardless of which key is requested.
    """

    def __init__(self) -> None:
        # Nothing is cached until ``set`` is called.
        self._cache = None

    def set(self, _key: object, value: object) -> None:
        """Store ``value`` as the cached entry, replacing any previous one."""
        self._cache = value

    def get(self, _key: object) -> object:
        """Return the last stored value, or ``None`` if nothing was stored."""
        return self._cache
@pytest.fixture
def mock_user():
    """Provide a MagicMock standing in for a user whose ``id`` is 1."""
    user = MagicMock()
    user.id = 1
    return user
@pytest.fixture
def screenshot_obj():
    """Provide a BaseScreenshot built from a sample URL and digest."""
    return BaseScreenshot("http://example.com", "sample_digest")
def test_get_screenshot(mocker: MockerFixture, screenshot_obj, mock_user):
    """get_screenshot should return the bytes produced by the driver.

    ``BaseScreenshot.driver`` is patched, so the object it returns hands back
    ``fake_bytes`` from its own ``get_screenshot`` call.
    """
    fake_bytes = b"fake_screenshot_data"
    driver = mocker.patch(BASE_SCREENSHOT_PATH + ".driver")
    driver.return_value.get_screenshot.return_value = fake_bytes
    # ``mock_user`` is requested as a fixture so that the MagicMock user is
    # passed here rather than the module-level fixture function object.
    screenshot_data = screenshot_obj.get_screenshot(mock_user)
    assert screenshot_data == fake_bytes
def test_get_cache_key(screenshot_obj):
    """get_cache_key must equal md5_sha_from_dict of the expected key fields."""
    key_fields = {
        "thumbnail_type": "",
        "digest": screenshot_obj.digest,
        "type": "thumb",
        "window_size": screenshot_obj.window_size,
        "thumb_size": screenshot_obj.thumb_size,
    }
    expected = md5_sha_from_dict(key_fields)
    actual = screenshot_obj.get_cache_key()
    assert actual == expected
def test_get_from_cache_key(mocker: MockerFixture, screenshot_obj):
    """get_from_cache_key should always return a ScreenshotCachePayload object."""
    # Backwards compatibility test: a cache entry holding plain bytes must
    # still come back wrapped in a ScreenshotCachePayload.
    fake_bytes = b"fake_screenshot_data"
    cache = MockCache()
    # Patch the class attribute through ``mocker`` so the original cache is
    # restored at teardown instead of leaking the mock into other tests.
    mocker.patch.object(BaseScreenshot, "cache", cache)
    cache.set("key", fake_bytes)
    cache_payload = screenshot_obj.get_from_cache_key("key")
    assert isinstance(cache_payload, ScreenshotCachePayload)
    assert cache_payload._image == fake_bytes  # pylint: disable=protected-access
class TestComputeAndCache:
    """Tests for ``BaseScreenshot.compute_and_cache``.

    Every test patches ``get_from_cache_key``, ``get_screenshot`` and
    ``resize_image`` on ``BaseScreenshot`` and swaps the class-level cache for
    a ``MockCache`` so the payload written by ``compute_and_cache`` can be
    inspected.
    """

    def _setup_compute_and_cache(self, mocker: MockerFixture, screenshot_obj):
        """Install the common patches and return the created mocks by name.

        Returns a dict with the keys ``get_from_cache_key``,
        ``get_screenshot`` and ``resize_image``.
        """
        get_from_cache_key = mocker.patch(
            BASE_SCREENSHOT_PATH + ".get_from_cache_key", return_value=None
        )
        get_screenshot = mocker.patch(
            BASE_SCREENSHOT_PATH + ".get_screenshot", return_value=b"new_image_data"
        )
        resize_image = mocker.patch(
            BASE_SCREENSHOT_PATH + ".resize_image", return_value=b"resized_image_data"
        )
        # Patch through ``mocker`` (rather than assigning to the class) so the
        # original cache is restored at teardown and does not leak into other
        # tests.
        mocker.patch.object(BaseScreenshot, "cache", MockCache())
        return {
            "get_from_cache_key": get_from_cache_key,
            "get_screenshot": get_screenshot,
            "resize_image": resize_image,
        }

    def test_happy_path(self, mocker: MockerFixture, screenshot_obj):
        """With nothing cached, the stored payload ends up UPDATED."""
        self._setup_compute_and_cache(mocker, screenshot_obj)
        screenshot_obj.compute_and_cache(force=False)
        cache_payload: ScreenshotCachePayload = screenshot_obj.cache.get("key")
        assert cache_payload.status == StatusValues.UPDATED

    def test_screenshot_error(self, mocker: MockerFixture, screenshot_obj):
        """If get_screenshot raises, the stored payload is marked ERROR."""
        mocks = self._setup_compute_and_cache(mocker, screenshot_obj)
        get_screenshot: MagicMock = mocks["get_screenshot"]
        get_screenshot.side_effect = Exception
        screenshot_obj.compute_and_cache(force=False)
        cache_payload: ScreenshotCachePayload = screenshot_obj.cache.get("key")
        assert cache_payload.status == StatusValues.ERROR

    def test_resize_error(self, mocker: MockerFixture, screenshot_obj):
        """If resize_image raises, the stored payload is marked ERROR."""
        mocks = self._setup_compute_and_cache(mocker, screenshot_obj)
        resize_image: MagicMock = mocks["resize_image"]
        resize_image.side_effect = Exception
        screenshot_obj.compute_and_cache(force=False)
        cache_payload: ScreenshotCachePayload = screenshot_obj.cache.get("key")
        assert cache_payload.status == StatusValues.ERROR

    def test_skips_if_computing(self, mocker: MockerFixture, screenshot_obj):
        """A payload already marked computing is only recomputed when forced."""
        mocks = self._setup_compute_and_cache(mocker, screenshot_obj)
        cached_value = ScreenshotCachePayload()
        cached_value.computing()
        mocks["get_from_cache_key"].return_value = cached_value
        get_screenshot: MagicMock = mocks["get_screenshot"]

        # Ensure that it skips when thumbnail status is computing
        screenshot_obj.compute_and_cache(force=False)
        get_screenshot.assert_not_called()

        # Ensure that it processes when force = True
        screenshot_obj.compute_and_cache(force=True)
        get_screenshot.assert_called_once()
        cache_payload: ScreenshotCachePayload = screenshot_obj.cache.get("key")
        assert cache_payload.status == StatusValues.UPDATED

    def test_skips_if_updated(self, mocker: MockerFixture, screenshot_obj):
        """A payload that already holds an image is only recomputed when forced."""
        mocks = self._setup_compute_and_cache(mocker, screenshot_obj)
        cached_value = ScreenshotCachePayload(image=b"initial_value")
        mocks["get_from_cache_key"].return_value = cached_value
        get_screenshot: MagicMock = mocks["get_screenshot"]

        # Ensure that it skips when thumbnail status is updated
        window_size = thumb_size = (10, 10)
        screenshot_obj.compute_and_cache(
            force=False, window_size=window_size, thumb_size=thumb_size
        )
        get_screenshot.assert_not_called()

        # Ensure that it processes when force = True
        screenshot_obj.compute_and_cache(
            force=True, window_size=window_size, thumb_size=thumb_size
        )
        get_screenshot.assert_called_once()
        cache_payload: ScreenshotCachePayload = screenshot_obj.cache.get("key")
        assert cache_payload._image != b"initial_value"  # pylint: disable=protected-access

    def test_resize(self, mocker: MockerFixture, screenshot_obj):
        """resize_image is called only when window_size differs from thumb_size."""
        mocks = self._setup_compute_and_cache(mocker, screenshot_obj)
        window_size = thumb_size = (10, 10)
        resize_image: MagicMock = mocks["resize_image"]

        # Equal sizes: no resize expected.
        screenshot_obj.compute_and_cache(
            force=False, window_size=window_size, thumb_size=thumb_size
        )
        resize_image.assert_not_called()

        # Differing sizes: exactly one resize expected.
        screenshot_obj.compute_and_cache(
            force=False, window_size=(1, 1), thumb_size=thumb_size
        )
        resize_image.assert_called_once()