fix(thumbnail cache): Enabling force parameter on screenshot/thumbnail cache (#31757)

Co-authored-by: Kamil Gabryjelski <kamil.gabryjelski@gmail.com>
This commit is contained in:
Jack 2025-01-31 12:22:31 -06:00 committed by GitHub
parent c590e90c87
commit 7db0589340
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 609 additions and 280 deletions

View File

@ -30,7 +30,7 @@ from marshmallow import ValidationError
from werkzeug.wrappers import Response as WerkzeugResponse from werkzeug.wrappers import Response as WerkzeugResponse
from werkzeug.wsgi import FileWrapper from werkzeug.wsgi import FileWrapper
from superset import app, is_feature_enabled, thumbnail_cache from superset import app, is_feature_enabled
from superset.charts.filters import ( from superset.charts.filters import (
ChartAllTextFilter, ChartAllTextFilter,
ChartCertifiedFilter, ChartCertifiedFilter,
@ -84,7 +84,12 @@ from superset.models.slice import Slice
from superset.tasks.thumbnails import cache_chart_thumbnail from superset.tasks.thumbnails import cache_chart_thumbnail
from superset.tasks.utils import get_current_user from superset.tasks.utils import get_current_user
from superset.utils import json from superset.utils import json
from superset.utils.screenshots import ChartScreenshot, DEFAULT_CHART_WINDOW_SIZE from superset.utils.screenshots import (
ChartScreenshot,
DEFAULT_CHART_WINDOW_SIZE,
ScreenshotCachePayload,
StatusValues,
)
from superset.utils.urls import get_url_path from superset.utils.urls import get_url_path
from superset.views.base_api import ( from superset.views.base_api import (
BaseSupersetModelRestApi, BaseSupersetModelRestApi,
@ -564,8 +569,14 @@ class ChartRestApi(BaseSupersetModelRestApi):
schema: schema:
$ref: '#/components/schemas/screenshot_query_schema' $ref: '#/components/schemas/screenshot_query_schema'
responses: responses:
200:
description: Chart async result
content:
application/json:
schema:
$ref: "#/components/schemas/ChartCacheScreenshotResponseSchema"
202: 202:
description: Chart async result description: Chart screenshot task created
content: content:
application/json: application/json:
schema: schema:
@ -580,6 +591,7 @@ class ChartRestApi(BaseSupersetModelRestApi):
$ref: '#/components/responses/500' $ref: '#/components/responses/500'
""" """
rison_dict = kwargs["rison"] rison_dict = kwargs["rison"]
force = rison_dict.get("force")
window_size = rison_dict.get("window_size") or DEFAULT_CHART_WINDOW_SIZE window_size = rison_dict.get("window_size") or DEFAULT_CHART_WINDOW_SIZE
# Don't shrink the image if thumb_size is not specified # Don't shrink the image if thumb_size is not specified
@ -591,25 +603,36 @@ class ChartRestApi(BaseSupersetModelRestApi):
chart_url = get_url_path("Superset.slice", slice_id=chart.id) chart_url = get_url_path("Superset.slice", slice_id=chart.id)
screenshot_obj = ChartScreenshot(chart_url, chart.digest) screenshot_obj = ChartScreenshot(chart_url, chart.digest)
cache_key = screenshot_obj.cache_key(window_size, thumb_size) cache_key = screenshot_obj.get_cache_key(window_size, thumb_size)
cache_payload = (
screenshot_obj.get_from_cache_key(cache_key) or ScreenshotCachePayload()
)
image_url = get_url_path( image_url = get_url_path(
"ChartRestApi.screenshot", pk=chart.id, digest=cache_key "ChartRestApi.screenshot", pk=chart.id, digest=cache_key
) )
def trigger_celery() -> WerkzeugResponse: def build_response(status_code: int) -> WerkzeugResponse:
return self.response(
status_code,
cache_key=cache_key,
chart_url=chart_url,
image_url=image_url,
task_updated_at=cache_payload.get_timestamp(),
task_status=cache_payload.get_status(),
)
if cache_payload.should_trigger_task(force):
logger.info("Triggering screenshot ASYNC") logger.info("Triggering screenshot ASYNC")
screenshot_obj.cache.set(cache_key, ScreenshotCachePayload())
cache_chart_thumbnail.delay( cache_chart_thumbnail.delay(
current_user=get_current_user(), current_user=get_current_user(),
chart_id=chart.id, chart_id=chart.id,
force=True,
window_size=window_size, window_size=window_size,
thumb_size=thumb_size, thumb_size=thumb_size,
force=force,
) )
return self.response( return build_response(202)
202, cache_key=cache_key, chart_url=chart_url, image_url=image_url return build_response(200)
)
return trigger_celery()
@expose("/<pk>/screenshot/<digest>/", methods=("GET",)) @expose("/<pk>/screenshot/<digest>/", methods=("GET",))
@protect() @protect()
@ -635,7 +658,7 @@ class ChartRestApi(BaseSupersetModelRestApi):
name: digest name: digest
responses: responses:
200: 200:
description: Chart thumbnail image description: Chart screenshot image
content: content:
image/*: image/*:
schema: schema:
@ -652,16 +675,16 @@ class ChartRestApi(BaseSupersetModelRestApi):
""" """
chart = self.datamodel.get(pk, self._base_filters) chart = self.datamodel.get(pk, self._base_filters)
# Making sure the chart still exists
if not chart: if not chart:
return self.response_404() return self.response_404()
# fetch the chart screenshot using the current user and cache if set if cache_payload := ChartScreenshot.get_from_cache_key(digest):
if img := ChartScreenshot.get_from_cache_key(thumbnail_cache, digest): if cache_payload.status == StatusValues.UPDATED:
return Response( return Response(
FileWrapper(img), mimetype="image/png", direct_passthrough=True FileWrapper(cache_payload.get_image()),
) mimetype="image/png",
# TODO: return an empty image direct_passthrough=True,
)
return self.response_404() return self.response_404()
@expose("/<pk>/thumbnail/<digest>/", methods=("GET",)) @expose("/<pk>/thumbnail/<digest>/", methods=("GET",))
@ -685,9 +708,10 @@ class ChartRestApi(BaseSupersetModelRestApi):
type: integer type: integer
name: pk name: pk
- in: path - in: path
name: digest
description: A hex digest that makes this chart unique
schema: schema:
type: string type: string
name: digest
responses: responses:
200: 200:
description: Chart thumbnail image description: Chart thumbnail image
@ -712,34 +736,6 @@ class ChartRestApi(BaseSupersetModelRestApi):
return self.response_404() return self.response_404()
current_user = get_current_user() current_user = get_current_user()
url = get_url_path("Superset.slice", slice_id=chart.id)
if kwargs["rison"].get("force", False):
logger.info(
"Triggering thumbnail compute (chart id: %s) ASYNC", str(chart.id)
)
cache_chart_thumbnail.delay(
current_user=current_user,
chart_id=chart.id,
force=True,
)
return self.response(202, message="OK Async")
# fetch the chart screenshot using the current user and cache if set
screenshot = ChartScreenshot(url, chart.digest).get_from_cache(
cache=thumbnail_cache
)
# If not screenshot then send request to compute thumb to celery
if not screenshot:
self.incr_stats("async", self.thumbnail.__name__)
logger.info(
"Triggering thumbnail compute (chart id: %s) ASYNC", str(chart.id)
)
cache_chart_thumbnail.delay(
current_user=current_user,
chart_id=chart.id,
force=True,
)
return self.response(202, message="OK Async")
# If the requested digest is outdated, redirect to the current digest
if chart.digest != digest: if chart.digest != digest:
self.incr_stats("redirect", self.thumbnail.__name__) self.incr_stats("redirect", self.thumbnail.__name__)
return redirect( return redirect(
@ -747,9 +743,34 @@ class ChartRestApi(BaseSupersetModelRestApi):
f"{self.__class__.__name__}.thumbnail", pk=pk, digest=chart.digest f"{self.__class__.__name__}.thumbnail", pk=pk, digest=chart.digest
) )
) )
url = get_url_path("Superset.slice", slice_id=chart.id)
screenshot_obj = ChartScreenshot(url, chart.digest)
cache_key = screenshot_obj.get_cache_key()
cache_payload = (
screenshot_obj.get_from_cache_key(cache_key) or ScreenshotCachePayload()
)
if cache_payload.should_trigger_task():
self.incr_stats("async", self.thumbnail.__name__)
logger.info(
"Triggering thumbnail compute (chart id: %s) ASYNC", str(chart.id)
)
screenshot_obj.cache.set(cache_key, ScreenshotCachePayload())
cache_chart_thumbnail.delay(
current_user=current_user,
chart_id=chart.id,
force=False,
)
return self.response(
202,
task_updated_at=cache_payload.get_timestamp(),
task_status=cache_payload.get_status(),
)
self.incr_stats("from_cache", self.thumbnail.__name__) self.incr_stats("from_cache", self.thumbnail.__name__)
return Response( return Response(
FileWrapper(screenshot), mimetype="image/png", direct_passthrough=True FileWrapper(cache_payload.get_image()),
mimetype="image/png",
direct_passthrough=True,
) )
@expose("/export/", methods=("GET",)) @expose("/export/", methods=("GET",))

View File

@ -304,6 +304,21 @@ class ChartCacheScreenshotResponseSchema(Schema):
image_url = fields.String( image_url = fields.String(
metadata={"description": "The url to fetch the screenshot"} metadata={"description": "The url to fetch the screenshot"}
) )
task_status = fields.String(
metadata={"description": "The status of the async screenshot"}
)
task_updated_at = fields.String(
metadata={"description": "The timestamp of the last change in status"}
)
# Marshmallow schema exposing only the async screenshot task state: its status
# and the timestamp of the last status change (both serialized as strings).
# NOTE(review): no reference to this schema is visible in this diff — confirm
# which endpoint is meant to use it.
class ChartGetCachedScreenshotResponseSchema(Schema):
task_status = fields.String(
metadata={"description": "The status of the async screenshot"}
)
task_updated_at = fields.String(
metadata={"description": "The timestamp of the last change in status"}
)
class ChartDataColumnSchema(Schema): class ChartDataColumnSchema(Schema):

View File

@ -729,8 +729,10 @@ THUMBNAIL_CHART_DIGEST_FUNC: Callable[[Slice, ExecutorType, str], str | None] |
THUMBNAIL_CACHE_CONFIG: CacheConfig = { THUMBNAIL_CACHE_CONFIG: CacheConfig = {
"CACHE_TYPE": "NullCache", "CACHE_TYPE": "NullCache",
"CACHE_DEFAULT_TIMEOUT": int(timedelta(days=7).total_seconds()),
"CACHE_NO_NULL_WARNING": True, "CACHE_NO_NULL_WARNING": True,
} }
THUMBNAIL_ERROR_CACHE_TTL = int(timedelta(days=1).total_seconds())
# Time before selenium times out after trying to locate an element on the page and wait # Time before selenium times out after trying to locate an element on the page and wait
# for that element to load for a screenshot. # for that element to load for a screenshot.

View File

@ -31,7 +31,7 @@ from marshmallow import ValidationError
from werkzeug.wrappers import Response as WerkzeugResponse from werkzeug.wrappers import Response as WerkzeugResponse
from werkzeug.wsgi import FileWrapper from werkzeug.wsgi import FileWrapper
from superset import db, thumbnail_cache from superset import db
from superset.charts.schemas import ChartEntityResponseSchema from superset.charts.schemas import ChartEntityResponseSchema
from superset.commands.dashboard.copy import CopyDashboardCommand from superset.commands.dashboard.copy import CopyDashboardCommand
from superset.commands.dashboard.create import CreateDashboardCommand from superset.commands.dashboard.create import CreateDashboardCommand
@ -115,6 +115,7 @@ from superset.utils.pdf import build_pdf_from_screenshots
from superset.utils.screenshots import ( from superset.utils.screenshots import (
DashboardScreenshot, DashboardScreenshot,
DEFAULT_DASHBOARD_WINDOW_SIZE, DEFAULT_DASHBOARD_WINDOW_SIZE,
ScreenshotCachePayload,
) )
from superset.utils.urls import get_url_path from superset.utils.urls import get_url_path
from superset.views.base_api import ( from superset.views.base_api import (
@ -1022,110 +1023,6 @@ class DashboardRestApi(BaseSupersetModelRestApi):
response.set_cookie(token, "done", max_age=600) response.set_cookie(token, "done", max_age=600)
return response return response
@expose("/<pk>/thumbnail/<digest>/", methods=("GET",))
@validate_feature_flags(["THUMBNAILS"])
@protect()
@safe
@rison(thumbnail_query_schema)
@event_logger.log_this_with_context(
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.thumbnail",
log_to_statsd=False,
)
def thumbnail(self, pk: int, digest: str, **kwargs: Any) -> WerkzeugResponse:
"""Compute async or get already computed dashboard thumbnail from cache.
---
get:
summary: Get dashboard's thumbnail
description: >-
Computes async or get already computed dashboard thumbnail from cache.
parameters:
- in: path
schema:
type: integer
name: pk
- in: path
name: digest
description: A hex digest that makes this dashboard unique
schema:
type: string
- in: query
name: q
content:
application/json:
schema:
$ref: '#/components/schemas/thumbnail_query_schema'
responses:
200:
description: Dashboard thumbnail image
content:
image/*:
schema:
type: string
format: binary
202:
description: Thumbnail does not exist on cache, fired async to compute
content:
application/json:
schema:
type: object
properties:
message:
type: string
302:
description: Redirects to the current digest
401:
$ref: '#/components/responses/401'
404:
$ref: '#/components/responses/404'
422:
$ref: '#/components/responses/422'
500:
$ref: '#/components/responses/500'
"""
dashboard = cast(Dashboard, self.datamodel.get(pk, self._base_filters))
if not dashboard:
return self.response_404()
dashboard_url = get_url_path(
"Superset.dashboard", dashboard_id_or_slug=dashboard.id
)
# If force, request a screenshot from the workers
current_user = get_current_user()
if kwargs["rison"].get("force", False):
cache_dashboard_thumbnail.delay(
current_user=current_user,
dashboard_id=dashboard.id,
force=True,
)
return self.response(202, message="OK Async")
# fetch the dashboard screenshot using the current user and cache if set
screenshot = DashboardScreenshot(
dashboard_url, dashboard.digest
).get_from_cache(cache=thumbnail_cache)
# If the screenshot does not exist, request one from the workers
if not screenshot:
self.incr_stats("async", self.thumbnail.__name__)
cache_dashboard_thumbnail.delay(
current_user=current_user,
dashboard_id=dashboard.id,
force=True,
)
return self.response(202, message="OK Async")
# If the requested digest is outdated, redirect to the current digest
if dashboard.digest != digest:
self.incr_stats("redirect", self.thumbnail.__name__)
return redirect(
url_for(
f"{self.__class__.__name__}.thumbnail",
pk=pk,
digest=dashboard.digest,
)
)
self.incr_stats("from_cache", self.thumbnail.__name__)
return Response(
FileWrapper(screenshot), mimetype="image/png", direct_passthrough=True
)
@expose("/<pk>/cache_dashboard_screenshot/", methods=("POST",)) @expose("/<pk>/cache_dashboard_screenshot/", methods=("POST",))
@validate_feature_flags(["THUMBNAILS", "ENABLE_DASHBOARD_SCREENSHOT_ENDPOINTS"]) @validate_feature_flags(["THUMBNAILS", "ENABLE_DASHBOARD_SCREENSHOT_ENDPOINTS"])
@protect() @protect()
@ -1172,7 +1069,6 @@ class DashboardRestApi(BaseSupersetModelRestApi):
payload = CacheScreenshotSchema().load(request.json) payload = CacheScreenshotSchema().load(request.json)
except ValidationError as error: except ValidationError as error:
return self.response_400(message=error.messages) return self.response_400(message=error.messages)
dashboard = cast(Dashboard, self.datamodel.get(pk, self._base_filters)) dashboard = cast(Dashboard, self.datamodel.get(pk, self._base_filters))
if not dashboard: if not dashboard:
return self.response_404() return self.response_404()
@ -1182,7 +1078,7 @@ class DashboardRestApi(BaseSupersetModelRestApi):
) )
# Don't shrink the image if thumb_size is not specified # Don't shrink the image if thumb_size is not specified
thumb_size = kwargs["rison"].get("thumb_size") or window_size thumb_size = kwargs["rison"].get("thumb_size") or window_size
force = kwargs["rison"].get("force", False)
dashboard_state: DashboardPermalinkState = { dashboard_state: DashboardPermalinkState = {
"dataMask": payload.get("dataMask", {}), "dataMask": payload.get("dataMask", {}),
"activeTabs": payload.get("activeTabs", []), "activeTabs": payload.get("activeTabs", []),
@ -1197,13 +1093,29 @@ class DashboardRestApi(BaseSupersetModelRestApi):
dashboard_url = get_url_path("Superset.dashboard_permalink", key=permalink_key) dashboard_url = get_url_path("Superset.dashboard_permalink", key=permalink_key)
screenshot_obj = DashboardScreenshot(dashboard_url, dashboard.digest) screenshot_obj = DashboardScreenshot(dashboard_url, dashboard.digest)
cache_key = screenshot_obj.cache_key(window_size, thumb_size, dashboard_state) cache_key = screenshot_obj.get_cache_key(
window_size, thumb_size, dashboard_state
)
image_url = get_url_path( image_url = get_url_path(
"DashboardRestApi.screenshot", pk=dashboard.id, digest=cache_key "DashboardRestApi.screenshot", pk=dashboard.id, digest=cache_key
) )
cache_payload = (
screenshot_obj.get_from_cache_key(cache_key) or ScreenshotCachePayload()
)
def trigger_celery() -> WerkzeugResponse: def build_response(status_code: int) -> WerkzeugResponse:
return self.response(
status_code,
cache_key=cache_key,
dashboard_url=dashboard_url,
image_url=image_url,
task_updated_at=cache_payload.get_timestamp(),
task_status=cache_payload.get_status(),
)
if cache_payload.should_trigger_task(force):
logger.info("Triggering screenshot ASYNC") logger.info("Triggering screenshot ASYNC")
screenshot_obj.cache.set(cache_key, ScreenshotCachePayload())
cache_dashboard_screenshot.delay( cache_dashboard_screenshot.delay(
username=get_current_user(), username=get_current_user(),
guest_token=( guest_token=(
@ -1213,19 +1125,12 @@ class DashboardRestApi(BaseSupersetModelRestApi):
), ),
dashboard_id=dashboard.id, dashboard_id=dashboard.id,
dashboard_url=dashboard_url, dashboard_url=dashboard_url,
cache_key=cache_key,
force=False,
thumb_size=thumb_size, thumb_size=thumb_size,
window_size=window_size, window_size=window_size,
force=force,
) )
return self.response( return build_response(202)
202, return build_response(200)
cache_key=cache_key,
dashboard_url=dashboard_url,
image_url=image_url,
)
return trigger_celery()
@expose("/<pk>/screenshot/<digest>/", methods=("GET",)) @expose("/<pk>/screenshot/<digest>/", methods=("GET",))
@validate_feature_flags(["THUMBNAILS", "ENABLE_DASHBOARD_SCREENSHOT_ENDPOINTS"]) @validate_feature_flags(["THUMBNAILS", "ENABLE_DASHBOARD_SCREENSHOT_ENDPOINTS"])
@ -1282,9 +1187,12 @@ class DashboardRestApi(BaseSupersetModelRestApi):
# fetch the dashboard screenshot using the current user and cache if set # fetch the dashboard screenshot using the current user and cache if set
if img := DashboardScreenshot.get_from_cache_key(thumbnail_cache, digest): if cache_payload := DashboardScreenshot.get_from_cache_key(digest):
image = cache_payload.get_image()
if not image:
return self.response_404()
if download_format == "pdf": if download_format == "pdf":
pdf_img = img.getvalue() pdf_img = image.getvalue()
# Convert the screenshot to PDF # Convert the screenshot to PDF
pdf_data = build_pdf_from_screenshots([pdf_img]) pdf_data = build_pdf_from_screenshots([pdf_img])
@ -1296,13 +1204,120 @@ class DashboardRestApi(BaseSupersetModelRestApi):
) )
if download_format == "png": if download_format == "png":
return Response( return Response(
FileWrapper(img), FileWrapper(image),
mimetype="image/png", mimetype="image/png",
direct_passthrough=True, direct_passthrough=True,
) )
return self.response_404() return self.response_404()
@expose("/<pk>/thumbnail/<digest>/", methods=("GET",))
@validate_feature_flags(["THUMBNAILS"])
@protect()
@safe
@rison(thumbnail_query_schema)
@event_logger.log_this_with_context(
    action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.thumbnail",
    log_to_statsd=False,
)
def thumbnail(self, pk: int, digest: str, **kwargs: Any) -> WerkzeugResponse:
    """Compute async or get already computed dashboard thumbnail from cache.
    ---
    get:
      summary: Get dashboard's thumbnail
      description: >-
        Computes async or get already computed dashboard thumbnail from cache.
      parameters:
      - in: path
        schema:
          type: integer
        name: pk
      - in: path
        name: digest
        description: A hex digest that makes this dashboard unique
        schema:
          type: string
      responses:
        200:
          description: Dashboard thumbnail image
          content:
            image/*:
              schema:
                type: string
                format: binary
        202:
          description: Thumbnail does not exist on cache, fired async to compute
          content:
            application/json:
              schema:
                type: object
                properties:
                  cache_key:
                    type: string
                  dashboard_url:
                    type: string
                  image_url:
                    type: string
                  task_status:
                    type: string
                  task_updated_at:
                    type: string
        302:
          description: Redirects to the current digest
        401:
          $ref: '#/components/responses/401'
        404:
          $ref: '#/components/responses/404'
        422:
          $ref: '#/components/responses/422'
        500:
          $ref: '#/components/responses/500'
    """
    dashboard = cast(Dashboard, self.datamodel.get(pk, self._base_filters))
    if not dashboard:
        return self.response_404()

    current_user = get_current_user()
    dashboard_url = get_url_path(
        "Superset.dashboard", dashboard_id_or_slug=dashboard.id
    )

    # The digest in the URL no longer matches the dashboard: send the caller
    # to the thumbnail URL built from the current digest.
    if dashboard.digest != digest:
        self.incr_stats("redirect", self.thumbnail.__name__)
        return redirect(
            url_for(
                f"{self.__class__.__name__}.thumbnail",
                pk=pk,
                digest=dashboard.digest,
            )
        )

    screenshot_obj = DashboardScreenshot(dashboard_url, dashboard.digest)
    cache_key = screenshot_obj.get_cache_key()
    # A cache miss is treated as a fresh (pending) payload.
    cache_payload = (
        screenshot_obj.get_from_cache_key(cache_key) or ScreenshotCachePayload()
    )
    image_url = get_url_path(
        "DashboardRestApi.thumbnail", pk=dashboard.id, digest=cache_key
    )

    if cache_payload.should_trigger_task():
        self.incr_stats("async", self.thumbnail.__name__)
        logger.info(
            "Triggering thumbnail compute (dashboard id: %s) ASYNC",
            str(dashboard.id),
        )
        # Store a fresh payload under the key before queueing the task.
        screenshot_obj.cache.set(cache_key, ScreenshotCachePayload())
        cache_dashboard_thumbnail.delay(
            current_user=current_user,
            dashboard_id=dashboard.id,
            force=False,
        )
        return self.response(
            202,
            cache_key=cache_key,
            dashboard_url=dashboard_url,
            image_url=image_url,
            task_updated_at=cache_payload.get_timestamp(),
            task_status=cache_payload.get_status(),
        )

    # No task was triggered, yet the payload may still carry no image (for
    # example while it is being computed, or after an error whose TTL has not
    # expired). Never hand ``None`` to FileWrapper; answer 404 as the
    # screenshot endpoint does.
    image = cache_payload.get_image()
    if image is None:
        return self.response_404()

    self.incr_stats("from_cache", self.thumbnail.__name__)
    return Response(
        FileWrapper(image),
        mimetype="image/png",
        direct_passthrough=True,
    )
@expose("/favorite_status/", methods=("GET",)) @expose("/favorite_status/", methods=("GET",))
@protect() @protect()
@safe @safe

View File

@ -507,6 +507,12 @@ class DashboardCacheScreenshotResponseSchema(Schema):
image_url = fields.String( image_url = fields.String(
metadata={"description": "The url to fetch the screenshot"} metadata={"description": "The url to fetch the screenshot"}
) )
task_status = fields.String(
metadata={"description": "The status of the async screenshot"}
)
task_updated_at = fields.String(
metadata={"description": "The timestamp of the last change in status"}
)
class CacheScreenshotSchema(Schema): class CacheScreenshotSchema(Schema):

View File

@ -380,9 +380,7 @@ def event_after_chart_changed(
_mapper: Mapper, _connection: Connection, target: Slice _mapper: Mapper, _connection: Connection, target: Slice
) -> None: ) -> None:
cache_chart_thumbnail.delay( cache_chart_thumbnail.delay(
current_user=get_current_user(), current_user=get_current_user(), chart_id=target.id, force=True
chart_id=target.id,
force=True,
) )

View File

@ -38,7 +38,7 @@ logger = logging.getLogger(__name__)
def cache_chart_thumbnail( def cache_chart_thumbnail(
current_user: Optional[str], current_user: Optional[str],
chart_id: int, chart_id: int,
force: bool = False, force: bool,
window_size: Optional[WindowSize] = None, window_size: Optional[WindowSize] = None,
thumb_size: Optional[WindowSize] = None, thumb_size: Optional[WindowSize] = None,
) -> None: ) -> None:
@ -64,10 +64,9 @@ def cache_chart_thumbnail(
screenshot = ChartScreenshot(url, chart.digest) screenshot = ChartScreenshot(url, chart.digest)
screenshot.compute_and_cache( screenshot.compute_and_cache(
user=user, user=user,
cache=thumbnail_cache,
force=force,
window_size=window_size, window_size=window_size,
thumb_size=thumb_size, thumb_size=thumb_size,
force=force,
) )
return None return None
@ -76,7 +75,7 @@ def cache_chart_thumbnail(
def cache_dashboard_thumbnail( def cache_dashboard_thumbnail(
current_user: Optional[str], current_user: Optional[str],
dashboard_id: int, dashboard_id: int,
force: bool = False, force: bool,
thumb_size: Optional[WindowSize] = None, thumb_size: Optional[WindowSize] = None,
window_size: Optional[WindowSize] = None, window_size: Optional[WindowSize] = None,
) -> None: ) -> None:
@ -101,10 +100,9 @@ def cache_dashboard_thumbnail(
screenshot = DashboardScreenshot(url, dashboard.digest) screenshot = DashboardScreenshot(url, dashboard.digest)
screenshot.compute_and_cache( screenshot.compute_and_cache(
user=user, user=user,
cache=thumbnail_cache,
force=force,
window_size=window_size, window_size=window_size,
thumb_size=thumb_size, thumb_size=thumb_size,
force=force,
) )
@ -113,7 +111,7 @@ def cache_dashboard_screenshot( # pylint: disable=too-many-arguments
username: str, username: str,
dashboard_id: int, dashboard_id: int,
dashboard_url: str, dashboard_url: str,
force: bool = True, force: bool,
cache_key: Optional[str] = None, cache_key: Optional[str] = None,
guest_token: Optional[GuestToken] = None, guest_token: Optional[GuestToken] = None,
thumb_size: Optional[WindowSize] = None, thumb_size: Optional[WindowSize] = None,
@ -145,9 +143,8 @@ def cache_dashboard_screenshot( # pylint: disable=too-many-arguments
screenshot = DashboardScreenshot(dashboard_url, dashboard.digest) screenshot = DashboardScreenshot(dashboard_url, dashboard.digest)
screenshot.compute_and_cache( screenshot.compute_and_cache(
user=current_user, user=current_user,
cache=thumbnail_cache,
force=force,
window_size=window_size, window_size=window_size,
thumb_size=thumb_size, thumb_size=thumb_size,
cache_key=cache_key, cache_key=cache_key,
force=force,
) )

View File

@ -17,12 +17,14 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from datetime import datetime
from enum import Enum
from io import BytesIO from io import BytesIO
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from flask import current_app from flask import current_app
from superset import feature_flag_manager from superset import app, feature_flag_manager, thumbnail_cache
from superset.dashboards.permalink.types import DashboardPermalinkState from superset.dashboards.permalink.types import DashboardPermalinkState
from superset.extensions import event_logger from superset.extensions import event_logger
from superset.utils.hashing import md5_sha_from_dict from superset.utils.hashing import md5_sha_from_dict
@ -54,6 +56,70 @@ if TYPE_CHECKING:
from flask_caching import Cache from flask_caching import Cache
# Lifecycle states recorded on a ScreenshotCachePayload. The string values are
# what ScreenshotCachePayload.get_status() returns.
class StatusValues(Enum):
# Initial state of a payload created without an image.
PENDING = "Pending"
# Set by ScreenshotCachePayload.computing().
COMPUTING = "Computing"
# The payload holds image bytes.
UPDATED = "Updated"
# Set by ScreenshotCachePayload.error().
ERROR = "Error"
class ScreenshotCachePayload:
    """Cache entry for one screenshot: image bytes, status and a timestamp.

    The timestamp is an ISO-8601 string produced from ``datetime.now()`` and is
    refreshed on every state transition.
    """

    def __init__(self, image: bytes | None = None):
        """Start as ``UPDATED`` when image bytes are supplied, else ``PENDING``."""
        self._image = image
        self._timestamp = datetime.now().isoformat()
        self.status = StatusValues.UPDATED if image else StatusValues.PENDING

    def update_timestamp(self) -> None:
        """Record the current time as the moment of the last change."""
        self._timestamp = datetime.now().isoformat()

    def _clear_image_and_mark(self, new_status: StatusValues) -> None:
        """Refresh the timestamp, drop the image and switch to ``new_status``."""
        self.update_timestamp()
        self._image = None
        self.status = new_status

    def pending(self) -> None:
        """Discard any image and mark the payload as ``PENDING``."""
        self._clear_image_and_mark(StatusValues.PENDING)

    def computing(self) -> None:
        """Discard any image and mark the payload as ``COMPUTING``."""
        self._clear_image_and_mark(StatusValues.COMPUTING)

    def update(self, image: bytes) -> None:
        """Store ``image`` and mark the payload as ``UPDATED``."""
        self.update_timestamp()
        self._image = image
        self.status = StatusValues.UPDATED

    def error(
        self,
    ) -> None:
        """Mark the payload as ``ERROR``; the stored image is left untouched."""
        self.update_timestamp()
        self.status = StatusValues.ERROR

    def get_image(self) -> BytesIO | None:
        """Return the image wrapped in a new ``BytesIO``, or None if empty."""
        return BytesIO(self._image) if self._image else None

    def get_timestamp(self) -> str:
        """Return the ISO-8601 timestamp of the last change."""
        return self._timestamp

    def get_status(self) -> str:
        """Return the string value of the current status."""
        return self.status.value

    def is_error_cache_ttl_expired(self) -> bool:
        """Tell whether more than ``THUMBNAIL_ERROR_CACHE_TTL`` seconds have
        passed since the last change."""
        ttl_seconds = app.config["THUMBNAIL_ERROR_CACHE_TTL"]
        now = datetime.now()
        last_change = datetime.fromisoformat(self.get_timestamp())
        elapsed = now - last_change
        return elapsed.total_seconds() > ttl_seconds

    def should_trigger_task(self, force: bool = False) -> bool:
        """Decide whether a new screenshot task should be queued.

        A task is due when ``force`` is truthy, when the payload is still
        ``PENDING``, or when it is in ``ERROR`` and the error TTL has expired.
        """
        if force:
            return force
        if self.status == StatusValues.PENDING:
            return True
        return self.status == StatusValues.ERROR and self.is_error_cache_ttl_expired()
class BaseScreenshot: class BaseScreenshot:
driver_type = current_app.config["WEBDRIVER_TYPE"] driver_type = current_app.config["WEBDRIVER_TYPE"]
url: str url: str
@ -63,6 +129,7 @@ class BaseScreenshot:
element: str = "" element: str = ""
window_size: WindowSize = DEFAULT_SCREENSHOT_WINDOW_SIZE window_size: WindowSize = DEFAULT_SCREENSHOT_WINDOW_SIZE
thumb_size: WindowSize = DEFAULT_SCREENSHOT_THUMBNAIL_SIZE thumb_size: WindowSize = DEFAULT_SCREENSHOT_THUMBNAIL_SIZE
cache: Cache = thumbnail_cache
def __init__(self, url: str, digest: str | None): def __init__(self, url: str, digest: str | None):
self.digest = digest self.digest = digest
@ -75,7 +142,14 @@ class BaseScreenshot:
return WebDriverPlaywright(self.driver_type, window_size) return WebDriverPlaywright(self.driver_type, window_size)
return WebDriverSelenium(self.driver_type, window_size) return WebDriverSelenium(self.driver_type, window_size)
def cache_key( def get_screenshot(
self, user: User, window_size: WindowSize | None = None
) -> bytes | None:
driver = self.driver(window_size)
self.screenshot = driver.get_screenshot(self.url, self.element, user)
return self.screenshot
def get_cache_key(
self, self,
window_size: bool | WindowSize | None = None, window_size: bool | WindowSize | None = None,
thumb_size: bool | WindowSize | None = None, thumb_size: bool | WindowSize | None = None,
@ -91,69 +165,35 @@ class BaseScreenshot:
} }
return md5_sha_from_dict(args) return md5_sha_from_dict(args)
def get_screenshot(
self, user: User, window_size: WindowSize | None = None
) -> bytes | None:
driver = self.driver(window_size)
with event_logger.log_context("screenshot", screenshot_url=self.url):
self.screenshot = driver.get_screenshot(self.url, self.element, user)
return self.screenshot
def get(
self,
user: User = None,
cache: Cache = None,
thumb_size: WindowSize | None = None,
) -> BytesIO | None:
"""
Get the thumbnail screenshot as BytesIO from the cache, or fetch it
:param user: None to use current user or User Model to login and fetch
:param cache: The cache to use
:param thumb_size: Override thumbnail size
"""
payload: bytes | None = None
cache_key = self.cache_key(self.window_size, thumb_size)
if cache:
payload = cache.get(cache_key)
if not payload:
payload = self.compute_and_cache(
user=user, thumb_size=thumb_size, cache=cache
)
else:
logger.info("Loaded thumbnail from cache: %s", cache_key)
if payload:
return BytesIO(payload)
return None
def get_from_cache( def get_from_cache(
self, self,
cache: Cache,
window_size: WindowSize | None = None, window_size: WindowSize | None = None,
thumb_size: WindowSize | None = None, thumb_size: WindowSize | None = None,
) -> BytesIO | None: ) -> ScreenshotCachePayload | None:
cache_key = self.cache_key(window_size, thumb_size) cache_key = self.get_cache_key(window_size, thumb_size)
return self.get_from_cache_key(cache, cache_key) return self.get_from_cache_key(cache_key)
@staticmethod @classmethod
def get_from_cache_key(cache: Cache, cache_key: str) -> BytesIO | None: def get_from_cache_key(cls, cache_key: str) -> ScreenshotCachePayload | None:
logger.info("Attempting to get from cache: %s", cache_key) logger.info("Attempting to get from cache: %s", cache_key)
if payload := cache.get(cache_key): if payload := cls.cache.get(cache_key):
return BytesIO(payload) # for backwards compatibility, byte objects should be converted
if not isinstance(payload, ScreenshotCachePayload):
payload = ScreenshotCachePayload(payload)
return payload
logger.info("Failed at getting from cache: %s", cache_key) logger.info("Failed at getting from cache: %s", cache_key)
return None return None
def compute_and_cache( # pylint: disable=too-many-arguments def compute_and_cache( # pylint: disable=too-many-arguments
self, self,
force: bool,
user: User = None, user: User = None,
window_size: WindowSize | None = None, window_size: WindowSize | None = None,
thumb_size: WindowSize | None = None, thumb_size: WindowSize | None = None,
cache: Cache = None,
force: bool = True,
cache_key: str | None = None, cache_key: str | None = None,
) -> bytes | None: ) -> None:
""" """
Fetches the screenshot, computes the thumbnail and caches the result Computes the thumbnail and caches the result
:param user: If no user is given will use the current context :param user: If no user is given will use the current context
:param cache: The cache to keep the thumbnail payload :param cache: The cache to keep the thumbnail payload
@ -162,40 +202,46 @@ class BaseScreenshot:
:param force: Will force the computation even if it's already cached :param force: Will force the computation even if it's already cached
:return: Image payload :return: Image payload
""" """
cache_key = cache_key or self.cache_key(window_size, thumb_size) cache_key = cache_key or self.get_cache_key(window_size, thumb_size)
cache_payload = self.get_from_cache_key(cache_key) or ScreenshotCachePayload()
if (
cache_payload.status in [StatusValues.COMPUTING, StatusValues.UPDATED]
and not force
):
logger.info(
"Skipping compute - already processed for thumbnail: %s", cache_key
)
return
window_size = window_size or self.window_size window_size = window_size or self.window_size
thumb_size = thumb_size or self.thumb_size thumb_size = thumb_size or self.thumb_size
if not force and cache and cache.get(cache_key):
logger.info("Thumb already cached, skipping...")
return None
logger.info("Processing url for thumbnail: %s", cache_key) logger.info("Processing url for thumbnail: %s", cache_key)
cache_payload.computing()
payload = None self.cache.set(cache_key, cache_payload)
image = None
# Assuming all sorts of things can go wrong with Selenium # Assuming all sorts of things can go wrong with Selenium
try: try:
with event_logger.log_context( logger.info("trying to generate screenshot")
f"screenshot.compute.{self.thumbnail_type}", force=force with event_logger.log_context(f"screenshot.compute.{self.thumbnail_type}"):
): image = self.get_screenshot(user=user, window_size=window_size)
payload = self.get_screenshot(user=user, window_size=window_size)
except Exception as ex: # pylint: disable=broad-except except Exception as ex: # pylint: disable=broad-except
logger.warning("Failed at generating thumbnail %s", ex, exc_info=True) logger.warning("Failed at generating thumbnail %s", ex, exc_info=True)
cache_payload.error()
if payload and window_size != thumb_size: if image and window_size != thumb_size:
try: try:
payload = self.resize_image(payload, thumb_size=thumb_size) image = self.resize_image(image, thumb_size=thumb_size)
except Exception as ex: # pylint: disable=broad-except except Exception as ex: # pylint: disable=broad-except
logger.warning("Failed at resizing thumbnail %s", ex, exc_info=True) logger.warning("Failed at resizing thumbnail %s", ex, exc_info=True)
payload = None cache_payload.error()
image = None
if payload: if image:
logger.info("Caching thumbnail: %s", cache_key) logger.info("Caching thumbnail: %s", cache_key)
with event_logger.log_context( with event_logger.log_context(f"screenshot.cache.{self.thumbnail_type}"):
f"screenshot.cache.{self.thumbnail_type}", force=force cache_payload.update(image)
): self.cache.set(cache_key, cache_payload)
cache.set(cache_key, payload) logger.info("Updated thumbnail cache; Status: %s", cache_payload.get_status())
logger.info("Done caching thumbnail") return
return payload
@classmethod @classmethod
def resize_image( def resize_image(
@ -265,7 +311,7 @@ class DashboardScreenshot(BaseScreenshot):
self.window_size = window_size or DEFAULT_DASHBOARD_WINDOW_SIZE self.window_size = window_size or DEFAULT_DASHBOARD_WINDOW_SIZE
self.thumb_size = thumb_size or DEFAULT_DASHBOARD_THUMBNAIL_SIZE self.thumb_size = thumb_size or DEFAULT_DASHBOARD_THUMBNAIL_SIZE
def cache_key( def get_cache_key(
self, self,
window_size: bool | WindowSize | None = None, window_size: bool | WindowSize | None = None,
thumb_size: bool | WindowSize | None = None, thumb_size: bool | WindowSize | None = None,

View File

@ -380,7 +380,7 @@ class WebDriverSelenium(WebDriverProxy):
return error_messages return error_messages
def get_screenshot(self, url: str, element_name: str, user: User) -> bytes | None: def get_screenshot(self, url: str, element_name: str, user: User) -> bytes | None: # noqa: C901
driver = self.auth(user) driver = self.auth(user)
driver.set_window_size(*self._window) driver.set_window_size(*self._window)
driver.get(url) driver.get(url)
@ -411,6 +411,7 @@ class WebDriverSelenium(WebDriverProxy):
) )
) )
except TimeoutException: except TimeoutException:
logger.info("Timeout Exception caught")
# Fallback to allow a screenshot of an empty dashboard # Fallback to allow a screenshot of an empty dashboard
try: try:
WebDriverWait(driver, 0).until( WebDriverWait(driver, 0).until(
@ -461,18 +462,23 @@ class WebDriverSelenium(WebDriverProxy):
) )
img = element.screenshot_as_png img = element.screenshot_as_png
except Exception as ex:
logger.warning("exception in webdriver", exc_info=ex)
raise
except TimeoutException: except TimeoutException:
# raise again for the finally block, but handled above # raise again for the finally block, but handled above
pass raise
except StaleElementReferenceException: except StaleElementReferenceException:
logger.exception( logger.exception(
"Selenium got a stale element while requesting url %s", "Selenium got a stale element while requesting url %s",
url, url,
) )
raise
except WebDriverException: except WebDriverException:
logger.exception( logger.exception(
"Encountered an unexpected error when requesting url %s", url "Encountered an unexpected error when requesting url %s", url
) )
raise
finally: finally:
self.destroy(driver, current_app.config["SCREENSHOT_SELENIUM_RETRIES"]) self.destroy(driver, current_app.config["SCREENSHOT_SELENIUM_RETRIES"])
return img return img

View File

@ -319,9 +319,5 @@ def test_compute_thumbnails(thumbnail_mock, app_context, fs):
["-d", "-i", dashboard.id], ["-d", "-i", dashboard.id],
) )
thumbnail_mock.assert_called_with( thumbnail_mock.assert_called_with(None, dashboard.id, force=False)
None,
dashboard.id,
force=False,
)
assert response.exit_code == 0 assert response.exit_code == 0

View File

@ -37,6 +37,7 @@ from superset.reports.models import ReportSchedule, ReportScheduleType
from superset.models.slice import Slice from superset.models.slice import Slice
from superset.tags.models import Tag, TaggedObject, TagType, ObjectType from superset.tags.models import Tag, TaggedObject, TagType, ObjectType
from superset.utils.core import backend, override_user from superset.utils.core import backend, override_user
from superset.utils.screenshots import ScreenshotCachePayload
from superset.utils import json from superset.utils import json
from tests.integration_tests.base_api_tests import ApiOwnersTestCaseMixin from tests.integration_tests.base_api_tests import ApiOwnersTestCaseMixin
@ -3069,13 +3070,15 @@ class TestDashboardApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCas
@pytest.mark.usefixtures("create_dashboard_with_tag") @pytest.mark.usefixtures("create_dashboard_with_tag")
@patch("superset.dashboards.api.cache_dashboard_screenshot") @patch("superset.dashboards.api.cache_dashboard_screenshot")
@patch("superset.dashboards.api.DashboardScreenshot.get_from_cache_key") @patch("superset.dashboards.api.DashboardScreenshot.get_from_cache_key")
def test_screenshot_success_png(self, mock_get_cache, mock_cache_task): def test_screenshot_success_png(self, mock_get_from_cache_key, mock_cache_task):
""" """
Validate screenshot returns png Validate screenshot returns png
""" """
self.login(ADMIN_USERNAME) self.login(ADMIN_USERNAME)
mock_cache_task.return_value = None mock_cache_task.return_value = None
mock_get_cache.return_value = BytesIO(b"fake image data") mock_get_from_cache_key.return_value = ScreenshotCachePayload(
b"fake image data"
)
dashboard = ( dashboard = (
db.session.query(Dashboard) db.session.query(Dashboard)
@ -3083,7 +3086,7 @@ class TestDashboardApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCas
.first() .first()
) )
cache_resp = self._cache_screenshot(dashboard.id) cache_resp = self._cache_screenshot(dashboard.id)
assert cache_resp.status_code == 202 assert cache_resp.status_code == 200
cache_key = json.loads(cache_resp.data.decode("utf-8"))["cache_key"] cache_key = json.loads(cache_resp.data.decode("utf-8"))["cache_key"]
response = self._get_screenshot(dashboard.id, cache_key, "png") response = self._get_screenshot(dashboard.id, cache_key, "png")
@ -3091,20 +3094,29 @@ class TestDashboardApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCas
assert response.mimetype == "image/png" assert response.mimetype == "image/png"
assert response.data == b"fake image data" assert response.data == b"fake image data"
mock_get_from_cache_key.return_value = ScreenshotCachePayload()
cache_resp = self._cache_screenshot(dashboard.id)
assert cache_resp.status_code == 202
@with_feature_flags(THUMBNAILS=True, ENABLE_DASHBOARD_SCREENSHOT_ENDPOINTS=True) @with_feature_flags(THUMBNAILS=True, ENABLE_DASHBOARD_SCREENSHOT_ENDPOINTS=True)
@pytest.mark.usefixtures("create_dashboard_with_tag") @pytest.mark.usefixtures("create_dashboard_with_tag")
@patch("superset.dashboards.api.cache_dashboard_screenshot") @patch("superset.dashboards.api.cache_dashboard_screenshot")
@patch("superset.dashboards.api.build_pdf_from_screenshots") @patch("superset.dashboards.api.build_pdf_from_screenshots")
@patch("superset.dashboards.api.DashboardScreenshot.get_from_cache_key") @patch("superset.dashboards.api.DashboardScreenshot.get_from_cache_key")
def test_screenshot_success_pdf( def test_screenshot_success_pdf(
self, mock_get_from_cache, mock_build_pdf, mock_cache_task self,
mock_get_from_cache_key,
mock_build_pdf,
mock_cache_task,
): ):
""" """
Validate screenshot can return pdf. Validate screenshot can return pdf.
""" """
self.login(ADMIN_USERNAME) self.login(ADMIN_USERNAME)
mock_cache_task.return_value = None mock_cache_task.return_value = None
mock_get_from_cache.return_value = BytesIO(b"fake image data") mock_get_from_cache_key.return_value = ScreenshotCachePayload(
b"fake image data"
)
mock_build_pdf.return_value = b"fake pdf data" mock_build_pdf.return_value = b"fake pdf data"
dashboard = ( dashboard = (
@ -3113,7 +3125,7 @@ class TestDashboardApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCas
.first() .first()
) )
cache_resp = self._cache_screenshot(dashboard.id) cache_resp = self._cache_screenshot(dashboard.id)
assert cache_resp.status_code == 202 assert cache_resp.status_code == 200
cache_key = json.loads(cache_resp.data.decode("utf-8"))["cache_key"] cache_key = json.loads(cache_resp.data.decode("utf-8"))["cache_key"]
response = self._get_screenshot(dashboard.id, cache_key, "pdf") response = self._get_screenshot(dashboard.id, cache_key, "pdf")
@ -3121,6 +3133,10 @@ class TestDashboardApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCas
assert response.mimetype == "application/pdf" assert response.mimetype == "application/pdf"
assert response.data == b"fake pdf data" assert response.data == b"fake pdf data"
mock_get_from_cache_key.return_value = ScreenshotCachePayload()
cache_resp = self._cache_screenshot(dashboard.id)
assert cache_resp.status_code == 202
@with_feature_flags(THUMBNAILS=True, ENABLE_DASHBOARD_SCREENSHOT_ENDPOINTS=True) @with_feature_flags(THUMBNAILS=True, ENABLE_DASHBOARD_SCREENSHOT_ENDPOINTS=True)
@pytest.mark.usefixtures("create_dashboard_with_tag") @pytest.mark.usefixtures("create_dashboard_with_tag")
@patch("superset.dashboards.api.cache_dashboard_screenshot") @patch("superset.dashboards.api.cache_dashboard_screenshot")
@ -3153,10 +3169,12 @@ class TestDashboardApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCas
@pytest.mark.usefixtures("create_dashboard_with_tag") @pytest.mark.usefixtures("create_dashboard_with_tag")
@patch("superset.dashboards.api.cache_dashboard_screenshot") @patch("superset.dashboards.api.cache_dashboard_screenshot")
@patch("superset.dashboards.api.DashboardScreenshot.get_from_cache_key") @patch("superset.dashboards.api.DashboardScreenshot.get_from_cache_key")
def test_screenshot_invalid_download_format(self, mock_get_cache, mock_cache_task): def test_screenshot_invalid_download_format(
self, mock_get_from_cache_key, mock_cache_task
):
self.login(ADMIN_USERNAME) self.login(ADMIN_USERNAME)
mock_cache_task.return_value = None mock_cache_task.return_value = None
mock_get_cache.return_value = BytesIO(b"fake png data") mock_get_from_cache_key.return_value = ScreenshotCachePayload(b"fake png data")
dashboard = ( dashboard = (
db.session.query(Dashboard) db.session.query(Dashboard)
@ -3165,9 +3183,13 @@ class TestDashboardApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCas
) )
cache_resp = self._cache_screenshot(dashboard.id) cache_resp = self._cache_screenshot(dashboard.id)
assert cache_resp.status_code == 202 assert cache_resp.status_code == 200
cache_key = json.loads(cache_resp.data.decode("utf-8"))["cache_key"] cache_key = json.loads(cache_resp.data.decode("utf-8"))["cache_key"]
mock_get_from_cache_key.return_value = ScreenshotCachePayload()
cache_resp = self._cache_screenshot(dashboard.id)
assert cache_resp.status_code == 202
response = self._get_screenshot(dashboard.id, cache_key, "invalid") response = self._get_screenshot(dashboard.id, cache_key, "invalid")
assert response.status_code == 404 assert response.status_code == 404

View File

@ -18,7 +18,6 @@
# from superset.models.dashboard import Dashboard # from superset.models.dashboard import Dashboard
import urllib.request import urllib.request
from io import BytesIO
from unittest import skipUnless from unittest import skipUnless
from unittest.mock import ANY, call, MagicMock, patch from unittest.mock import ANY, call, MagicMock, patch
@ -32,7 +31,11 @@ from superset.models.dashboard import Dashboard
from superset.models.slice import Slice from superset.models.slice import Slice
from superset.tasks.types import ExecutorType, FixedExecutor from superset.tasks.types import ExecutorType, FixedExecutor
from superset.utils import json from superset.utils import json
from superset.utils.screenshots import ChartScreenshot, DashboardScreenshot from superset.utils.screenshots import (
ChartScreenshot,
DashboardScreenshot,
ScreenshotCachePayload,
)
from superset.utils.urls import get_url_path from superset.utils.urls import get_url_path
from superset.utils.webdriver import WebDriverSelenium from superset.utils.webdriver import WebDriverSelenium
from tests.integration_tests.base_tests import SupersetTestCase from tests.integration_tests.base_tests import SupersetTestCase
@ -287,14 +290,14 @@ class TestThumbnails(SupersetTestCase):
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature") @skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature")
def test_get_async_dashboard_not_allowed(self): def test_get_async_dashboard_created(self):
""" """
Thumbnails: Simple get async dashboard not allowed Thumbnails: Simple get async dashboard not allowed
""" """
self.login(ADMIN_USERNAME) self.login(ADMIN_USERNAME)
_, thumbnail_url = self._get_id_and_thumbnail_url(DASHBOARD_URL) _, thumbnail_url = self._get_id_and_thumbnail_url(DASHBOARD_URL)
rv = self.client.get(thumbnail_url) rv = self.client.get(thumbnail_url)
assert rv.status_code == 404 assert rv.status_code == 202
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@with_feature_flags(THUMBNAILS=True) @with_feature_flags(THUMBNAILS=True)
@ -370,7 +373,9 @@ class TestThumbnails(SupersetTestCase):
Thumbnails: Simple get chart with wrong digest Thumbnails: Simple get chart with wrong digest
""" """
with patch.object( with patch.object(
ChartScreenshot, "get_from_cache", return_value=BytesIO(self.mock_image) ChartScreenshot,
"get_from_cache",
return_value=ScreenshotCachePayload(self.mock_image),
): ):
self.login(ADMIN_USERNAME) self.login(ADMIN_USERNAME)
id_, thumbnail_url = self._get_id_and_thumbnail_url(CHART_URL) id_, thumbnail_url = self._get_id_and_thumbnail_url(CHART_URL)
@ -385,7 +390,9 @@ class TestThumbnails(SupersetTestCase):
Thumbnails: Simple get cached dashboard screenshot Thumbnails: Simple get cached dashboard screenshot
""" """
with patch.object( with patch.object(
DashboardScreenshot, "get_from_cache", return_value=BytesIO(self.mock_image) DashboardScreenshot,
"get_from_cache_key",
return_value=ScreenshotCachePayload(self.mock_image),
): ):
self.login(ADMIN_USERNAME) self.login(ADMIN_USERNAME)
_, thumbnail_url = self._get_id_and_thumbnail_url(DASHBOARD_URL) _, thumbnail_url = self._get_id_and_thumbnail_url(DASHBOARD_URL)
@ -400,7 +407,9 @@ class TestThumbnails(SupersetTestCase):
Thumbnails: Simple get cached chart screenshot Thumbnails: Simple get cached chart screenshot
""" """
with patch.object( with patch.object(
ChartScreenshot, "get_from_cache", return_value=BytesIO(self.mock_image) ChartScreenshot,
"get_from_cache_key",
return_value=ScreenshotCachePayload(self.mock_image),
): ):
self.login(ADMIN_USERNAME) self.login(ADMIN_USERNAME)
id_, thumbnail_url = self._get_id_and_thumbnail_url(CHART_URL) id_, thumbnail_url = self._get_id_and_thumbnail_url(CHART_URL)
@ -415,7 +424,9 @@ class TestThumbnails(SupersetTestCase):
Thumbnails: Simple get dashboard with wrong digest Thumbnails: Simple get dashboard with wrong digest
""" """
with patch.object( with patch.object(
DashboardScreenshot, "get_from_cache", return_value=BytesIO(self.mock_image) DashboardScreenshot,
"get_from_cache",
return_value=ScreenshotCachePayload(self.mock_image),
): ):
self.login(ADMIN_USERNAME) self.login(ADMIN_USERNAME)
id_, thumbnail_url = self._get_id_and_thumbnail_url(DASHBOARD_URL) id_, thumbnail_url = self._get_id_and_thumbnail_url(DASHBOARD_URL)

View File

@ -0,0 +1,194 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=import-outside-toplevel, unused-argument
from unittest.mock import MagicMock
import pytest
from pytest_mock import MockerFixture
from superset.utils.hashing import md5_sha_from_dict
from superset.utils.screenshots import (
BaseScreenshot,
ScreenshotCachePayload,
StatusValues,
)
BASE_SCREENSHOT_PATH = "superset.utils.screenshots.BaseScreenshot"
class MockCache:
    """Single-slot stand-in for the screenshot cache; keys are ignored."""

    def __init__(self):
        # Whatever was stored last, regardless of the key it was stored under.
        self._stored = None

    def get(self, _key):
        """Return the most recently stored value (None if nothing was set)."""
        return self._stored

    def set(self, _key, value):
        """Remember ``value`` as the one and only cached entry."""
        self._stored = value
@pytest.fixture
def mock_user():
    """Provide a MagicMock standing in for a user, with ``id`` set to 1."""
    return MagicMock(id=1)
@pytest.fixture
def screenshot_obj():
    """Build the BaseScreenshot instance used by the tests in this module."""
    return BaseScreenshot("http://example.com", "sample_digest")
def test_get_screenshot(mocker: MockerFixture, screenshot_obj, mock_user):
    """get_screenshot should return the bytes produced by the driver."""
    fake_bytes = b"fake_screenshot_data"
    # Patch BaseScreenshot.driver so the value its get_screenshot returns is
    # fully controlled by this test.
    driver = mocker.patch(BASE_SCREENSHOT_PATH + ".driver")
    driver.return_value.get_screenshot.return_value = fake_bytes
    # ``mock_user`` is requested as a fixture argument; without it the name
    # resolves to the module-level fixture *function*, not the mock user.
    screenshot_data = screenshot_obj.get_screenshot(mock_user)
    assert screenshot_data == fake_bytes
def test_get_cache_key(screenshot_obj):
    """get_cache_key() with no arguments should match the hash of these fields."""
    key_fields = {
        "thumbnail_type": "",
        "digest": screenshot_obj.digest,
        "type": "thumb",
        "window_size": screenshot_obj.window_size,
        "thumb_size": screenshot_obj.thumb_size,
    }
    expected = md5_sha_from_dict(key_fields)
    assert screenshot_obj.get_cache_key() == expected
def test_get_from_cache_key(mocker: MockerFixture, screenshot_obj):
    """get_from_cache_key should always return a ScreenshotCachePayload object."""
    # Backwards-compatibility case: the cache still holds plain bytes.
    fake_bytes = b"fake_screenshot_data"
    cache = MockCache()
    cache.set("key", fake_bytes)
    # patch.object is reverted when the test finishes, so the fake cache does
    # not stay installed on BaseScreenshot for tests that run afterwards.
    mocker.patch.object(BaseScreenshot, "cache", cache)
    cache_payload = screenshot_obj.get_from_cache_key("key")
    assert isinstance(cache_payload, ScreenshotCachePayload)
    assert cache_payload._image == fake_bytes  # pylint: disable=protected-access
class TestComputeAndCache:
    """Tests for BaseScreenshot.compute_and_cache using patched collaborators."""

    def _setup_compute_and_cache(self, mocker: MockerFixture, screenshot_obj):
        """Patch the collaborators of compute_and_cache and return the mocks.

        :param mocker: pytest-mock fixture used to install the patches
        :param screenshot_obj: unused; kept so existing call sites stay valid
        :return: dict with the ``get_from_cache_key``, ``get_screenshot`` and
            ``resize_image`` mocks
        """
        get_from_cache_key = mocker.patch(
            BASE_SCREENSHOT_PATH + ".get_from_cache_key", return_value=None
        )
        get_screenshot = mocker.patch(
            BASE_SCREENSHOT_PATH + ".get_screenshot", return_value=b"new_image_data"
        )
        resize_image = mocker.patch(
            BASE_SCREENSHOT_PATH + ".resize_image", return_value=b"resized_image_data"
        )
        # Install the fake cache through the mocker so it is removed again at
        # teardown instead of permanently overwriting the class attribute.
        mocker.patch.object(BaseScreenshot, "cache", MockCache())
        return {
            "get_from_cache_key": get_from_cache_key,
            "get_screenshot": get_screenshot,
            "resize_image": resize_image,
        }

    def test_happy_path(self, mocker: MockerFixture, screenshot_obj):
        """A successful run leaves an UPDATED payload in the cache."""
        self._setup_compute_and_cache(mocker, screenshot_obj)
        screenshot_obj.compute_and_cache(force=False)
        cache_payload: ScreenshotCachePayload = screenshot_obj.cache.get("key")
        assert cache_payload.status == StatusValues.UPDATED

    def test_screenshot_error(self, mocker: MockerFixture, screenshot_obj):
        """An exception while taking the screenshot is recorded as ERROR."""
        mocks = self._setup_compute_and_cache(mocker, screenshot_obj)
        get_screenshot: MagicMock = mocks["get_screenshot"]
        get_screenshot.side_effect = Exception
        screenshot_obj.compute_and_cache(force=False)
        cache_payload: ScreenshotCachePayload = screenshot_obj.cache.get("key")
        assert cache_payload.status == StatusValues.ERROR

    def test_resize_error(self, mocker: MockerFixture, screenshot_obj):
        """An exception while resizing is recorded as ERROR."""
        mocks = self._setup_compute_and_cache(mocker, screenshot_obj)
        resize_image: MagicMock = mocks["resize_image"]
        resize_image.side_effect = Exception
        screenshot_obj.compute_and_cache(force=False)
        cache_payload: ScreenshotCachePayload = screenshot_obj.cache.get("key")
        assert cache_payload.status == StatusValues.ERROR

    def test_skips_if_computing(self, mocker: MockerFixture, screenshot_obj):
        """A COMPUTING payload is skipped unless force=True."""
        mocks = self._setup_compute_and_cache(mocker, screenshot_obj)
        cached_value = ScreenshotCachePayload()
        cached_value.computing()
        mocks["get_from_cache_key"].return_value = cached_value
        get_screenshot = mocks["get_screenshot"]

        # Ensure that it skips when thumbnail status is computing
        screenshot_obj.compute_and_cache(force=False)
        get_screenshot.assert_not_called()

        # Ensure that it processes when force = True
        screenshot_obj.compute_and_cache(force=True)
        get_screenshot.assert_called_once()
        cache_payload: ScreenshotCachePayload = screenshot_obj.cache.get("key")
        assert cache_payload.status == StatusValues.UPDATED

    def test_skips_if_updated(self, mocker: MockerFixture, screenshot_obj):
        """An already-populated payload is skipped unless force=True."""
        mocks = self._setup_compute_and_cache(mocker, screenshot_obj)
        cached_value = ScreenshotCachePayload(image=b"initial_value")
        mocks["get_from_cache_key"].return_value = cached_value
        get_screenshot = mocks["get_screenshot"]

        # Ensure that it skips when thumbnail status is updated
        window_size = thumb_size = (10, 10)
        screenshot_obj.compute_and_cache(
            force=False, window_size=window_size, thumb_size=thumb_size
        )
        get_screenshot.assert_not_called()

        # Ensure that it processes when force = True
        screenshot_obj.compute_and_cache(
            force=True, window_size=window_size, thumb_size=thumb_size
        )
        get_screenshot.assert_called_once()
        cache_payload: ScreenshotCachePayload = screenshot_obj.cache.get("key")
        assert cache_payload._image != b"initial_value"  # pylint: disable=protected-access

    def test_resize(self, mocker: MockerFixture, screenshot_obj):
        """resize_image is only called when window and thumb sizes differ."""
        mocks = self._setup_compute_and_cache(mocker, screenshot_obj)
        window_size = thumb_size = (10, 10)
        resize_image: MagicMock = mocks["resize_image"]

        # Equal sizes: nothing to resize.
        screenshot_obj.compute_and_cache(
            force=False, window_size=window_size, thumb_size=thumb_size
        )
        resize_image.assert_not_called()

        # Different sizes: exactly one resize.
        screenshot_obj.compute_and_cache(
            force=False, window_size=(1, 1), thumb_size=thumb_size
        )
        resize_image.assert_called_once()