From 7db0589340603b03246d9f2f4e37233ece968a14 Mon Sep 17 00:00:00 2001 From: Jack <41238731+fisjac@users.noreply.github.com> Date: Fri, 31 Jan 2025 12:22:31 -0600 Subject: [PATCH] fix(thumbnail cache): Enabling force parameter on screenshot/thumbnail cache (#31757) Co-authored-by: Kamil Gabryjelski --- superset/charts/api.py | 119 ++++---- superset/charts/schemas.py | 15 + superset/config.py | 2 + superset/dashboards/api.py | 261 +++++++++--------- superset/dashboards/schemas.py | 6 + superset/models/slice.py | 4 +- superset/tasks/thumbnails.py | 15 +- superset/utils/screenshots.py | 190 ++++++++----- superset/utils/webdriver.py | 10 +- tests/integration_tests/cli_tests.py | 6 +- .../integration_tests/dashboards/api_tests.py | 40 ++- tests/integration_tests/thumbnails_tests.py | 27 +- tests/unit_tests/utils/screenshot_test.py | 194 +++++++++++++ 13 files changed, 609 insertions(+), 280 deletions(-) create mode 100644 tests/unit_tests/utils/screenshot_test.py diff --git a/superset/charts/api.py b/superset/charts/api.py index c7266866e..a600a3ca7 100644 --- a/superset/charts/api.py +++ b/superset/charts/api.py @@ -30,7 +30,7 @@ from marshmallow import ValidationError from werkzeug.wrappers import Response as WerkzeugResponse from werkzeug.wsgi import FileWrapper -from superset import app, is_feature_enabled, thumbnail_cache +from superset import app, is_feature_enabled from superset.charts.filters import ( ChartAllTextFilter, ChartCertifiedFilter, @@ -84,7 +84,12 @@ from superset.models.slice import Slice from superset.tasks.thumbnails import cache_chart_thumbnail from superset.tasks.utils import get_current_user from superset.utils import json -from superset.utils.screenshots import ChartScreenshot, DEFAULT_CHART_WINDOW_SIZE +from superset.utils.screenshots import ( + ChartScreenshot, + DEFAULT_CHART_WINDOW_SIZE, + ScreenshotCachePayload, + StatusValues, +) from superset.utils.urls import get_url_path from superset.views.base_api import ( BaseSupersetModelRestApi, @@ -564,8 +569,14 @@ class ChartRestApi(BaseSupersetModelRestApi): schema: $ref: '#/components/schemas/screenshot_query_schema' responses: + 200: + description: Chart async result + content: + application/json: + schema: + $ref: "#/components/schemas/ChartCacheScreenshotResponseSchema" 202: - description: Chart async result + description: Chart screenshot task created content: application/json: schema: @@ -580,6 +591,7 @@ class ChartRestApi(BaseSupersetModelRestApi): $ref: '#/components/responses/500' """ rison_dict = kwargs["rison"] + force = rison_dict.get("force") window_size = rison_dict.get("window_size") or DEFAULT_CHART_WINDOW_SIZE # Don't shrink the image if thumb_size is not specified @@ -591,25 +603,36 @@ class ChartRestApi(BaseSupersetModelRestApi): chart_url = get_url_path("Superset.slice", slice_id=chart.id) screenshot_obj = ChartScreenshot(chart_url, chart.digest) - cache_key = screenshot_obj.cache_key(window_size, thumb_size) + cache_key = screenshot_obj.get_cache_key(window_size, thumb_size) + cache_payload = ( + screenshot_obj.get_from_cache_key(cache_key) or ScreenshotCachePayload() + ) image_url = get_url_path( "ChartRestApi.screenshot", pk=chart.id, digest=cache_key ) - def trigger_celery() -> WerkzeugResponse: + def build_response(status_code: int) -> WerkzeugResponse: + return self.response( + status_code, + cache_key=cache_key, + chart_url=chart_url, + image_url=image_url, + task_updated_at=cache_payload.get_timestamp(), + task_status=cache_payload.get_status(), + ) + + if cache_payload.should_trigger_task(force): logger.info("Triggering screenshot ASYNC") + screenshot_obj.cache.set(cache_key, ScreenshotCachePayload()) cache_chart_thumbnail.delay( current_user=get_current_user(), chart_id=chart.id, - force=True, window_size=window_size, thumb_size=thumb_size, + force=force, ) - return self.response( - 202, cache_key=cache_key, chart_url=chart_url, image_url=image_url - ) - - return trigger_celery() + return build_response(202) + return build_response(200) @expose("//screenshot//", methods=("GET",)) @protect() @@ -635,7 +658,7 @@ class ChartRestApi(BaseSupersetModelRestApi): name: digest responses: 200: - description: Chart thumbnail image + description: Chart screenshot image content: image/*: schema: @@ -652,16 +675,16 @@ class ChartRestApi(BaseSupersetModelRestApi): """ chart = self.datamodel.get(pk, self._base_filters) - # Making sure the chart still exists if not chart: return self.response_404() - # fetch the chart screenshot using the current user and cache if set - if img := ChartScreenshot.get_from_cache_key(thumbnail_cache, digest): - return Response( - FileWrapper(img), mimetype="image/png", direct_passthrough=True - ) - # TODO: return an empty image + if cache_payload := ChartScreenshot.get_from_cache_key(digest): + if cache_payload.status == StatusValues.UPDATED: + return Response( + FileWrapper(cache_payload.get_image()), + mimetype="image/png", + direct_passthrough=True, + ) return self.response_404() @expose("//thumbnail//", methods=("GET",)) @@ -685,9 +708,10 @@ class ChartRestApi(BaseSupersetModelRestApi): type: integer name: pk - in: path + name: digest + description: A hex digest that makes this chart unique schema: type: string - name: digest responses: 200: description: Chart thumbnail image @@ -712,34 +736,6 @@ class ChartRestApi(BaseSupersetModelRestApi): return self.response_404() current_user = get_current_user() - url = get_url_path("Superset.slice", slice_id=chart.id) - if kwargs["rison"].get("force", False): - logger.info( - "Triggering thumbnail compute (chart id: %s) ASYNC", str(chart.id) - ) - cache_chart_thumbnail.delay( - current_user=current_user, - chart_id=chart.id, - force=True, - ) - return self.response(202, message="OK Async") - # fetch the chart screenshot using the current user and cache if set - screenshot = ChartScreenshot(url, chart.digest).get_from_cache( - cache=thumbnail_cache - ) - # If not screenshot then send request to compute thumb to celery - if not screenshot: - self.incr_stats("async", self.thumbnail.__name__) - logger.info( - "Triggering thumbnail compute (chart id: %s) ASYNC", str(chart.id) - ) - cache_chart_thumbnail.delay( - current_user=current_user, - chart_id=chart.id, - force=True, - ) - return self.response(202, message="OK Async") - # If digests if chart.digest != digest: self.incr_stats("redirect", self.thumbnail.__name__) return redirect( @@ -747,9 +743,34 @@ class ChartRestApi(BaseSupersetModelRestApi): f"{self.__class__.__name__}.thumbnail", pk=pk, digest=chart.digest ) ) + url = get_url_path("Superset.slice", slice_id=chart.id) + screenshot_obj = ChartScreenshot(url, chart.digest) + cache_key = screenshot_obj.get_cache_key() + cache_payload = ( + screenshot_obj.get_from_cache_key(cache_key) or ScreenshotCachePayload() + ) + + if cache_payload.should_trigger_task(): + self.incr_stats("async", self.thumbnail.__name__) + logger.info( + "Triggering thumbnail compute (chart id: %s) ASYNC", str(chart.id) + ) + screenshot_obj.cache.set(cache_key, ScreenshotCachePayload()) + cache_chart_thumbnail.delay( + current_user=current_user, + chart_id=chart.id, + force=False, + ) + return self.response( + 202, + task_updated_at=cache_payload.get_timestamp(), + task_status=cache_payload.get_status(), + ) self.incr_stats("from_cache", self.thumbnail.__name__) return Response( - FileWrapper(screenshot), mimetype="image/png", direct_passthrough=True + FileWrapper(cache_payload.get_image()), + mimetype="image/png", + direct_passthrough=True, ) @expose("/export/", methods=("GET",)) diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py index 5531e057c..7faa42aeb 100644 --- a/superset/charts/schemas.py +++ b/superset/charts/schemas.py @@ -304,6 +304,21 @@ class ChartCacheScreenshotResponseSchema(Schema): image_url = fields.String( metadata={"description": "The url to fetch the screenshot"} ) + task_status = fields.String( + metadata={"description": "The status of the async screenshot"} + ) + task_updated_at = fields.String( + metadata={"description": "The timestamp of the last change in status"} + ) + + +class ChartGetCachedScreenshotResponseSchema(Schema): + task_status = fields.String( + metadata={"description": "The status of the async screenshot"} + ) + task_updated_at = fields.String( + metadata={"description": "The timestamp of the last change in status"} + ) class ChartDataColumnSchema(Schema): diff --git a/superset/config.py b/superset/config.py index 1182042ec..6362e39ae 100644 --- a/superset/config.py +++ b/superset/config.py @@ -729,8 +729,10 @@ THUMBNAIL_CHART_DIGEST_FUNC: Callable[[Slice, ExecutorType, str], str | None] | THUMBNAIL_CACHE_CONFIG: CacheConfig = { "CACHE_TYPE": "NullCache", + "CACHE_DEFAULT_TIMEOUT": int(timedelta(days=7).total_seconds()), "CACHE_NO_NULL_WARNING": True, } +THUMBNAIL_ERROR_CACHE_TTL = int(timedelta(days=1).total_seconds()) # Time before selenium times out after trying to locate an element on the page and wait # for that element to load for a screenshot. diff --git a/superset/dashboards/api.py b/superset/dashboards/api.py index ab83a01d9..c8c744ec6 100644 --- a/superset/dashboards/api.py +++ b/superset/dashboards/api.py @@ -31,7 +31,7 @@ from marshmallow import ValidationError from werkzeug.wrappers import Response as WerkzeugResponse from werkzeug.wsgi import FileWrapper -from superset import db, thumbnail_cache +from superset import db from superset.charts.schemas import ChartEntityResponseSchema from superset.commands.dashboard.copy import CopyDashboardCommand from superset.commands.dashboard.create import CreateDashboardCommand @@ -115,6 +115,7 @@ from superset.utils.pdf import build_pdf_from_screenshots from superset.utils.screenshots import ( DashboardScreenshot, DEFAULT_DASHBOARD_WINDOW_SIZE, + ScreenshotCachePayload, ) from superset.utils.urls import get_url_path from superset.views.base_api import ( @@ -1022,110 +1023,6 @@ class DashboardRestApi(BaseSupersetModelRestApi): response.set_cookie(token, "done", max_age=600) return response - @expose("//thumbnail//", methods=("GET",)) - @validate_feature_flags(["THUMBNAILS"]) - @protect() - @safe - @rison(thumbnail_query_schema) - @event_logger.log_this_with_context( - action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.thumbnail", - log_to_statsd=False, - ) - def thumbnail(self, pk: int, digest: str, **kwargs: Any) -> WerkzeugResponse: - """Compute async or get already computed dashboard thumbnail from cache. - --- - get: - summary: Get dashboard's thumbnail - description: >- - Computes async or get already computed dashboard thumbnail from cache. - parameters: - - in: path - schema: - type: integer - name: pk - - in: path - name: digest - description: A hex digest that makes this dashboard unique - schema: - type: string - - in: query - name: q - content: - application/json: - schema: - $ref: '#/components/schemas/thumbnail_query_schema' - responses: - 200: - description: Dashboard thumbnail image - content: - image/*: - schema: - type: string - format: binary - 202: - description: Thumbnail does not exist on cache, fired async to compute - content: - application/json: - schema: - type: object - properties: - message: - type: string - 302: - description: Redirects to the current digest - 401: - $ref: '#/components/responses/401' - 404: - $ref: '#/components/responses/404' - 422: - $ref: '#/components/responses/422' - 500: - $ref: '#/components/responses/500' - """ - dashboard = cast(Dashboard, self.datamodel.get(pk, self._base_filters)) - if not dashboard: - return self.response_404() - - dashboard_url = get_url_path( - "Superset.dashboard", dashboard_id_or_slug=dashboard.id - ) - # If force, request a screenshot from the workers - current_user = get_current_user() - if kwargs["rison"].get("force", False): - cache_dashboard_thumbnail.delay( - current_user=current_user, - dashboard_id=dashboard.id, - force=True, - ) - return self.response(202, message="OK Async") - # fetch the dashboard screenshot using the current user and cache if set - screenshot = DashboardScreenshot( - dashboard_url, dashboard.digest - ).get_from_cache(cache=thumbnail_cache) - # If the screenshot does not exist, request one from the workers - if not screenshot: - self.incr_stats("async", self.thumbnail.__name__) - cache_dashboard_thumbnail.delay( - current_user=current_user, - dashboard_id=dashboard.id, - force=True, - ) - return self.response(202, message="OK Async") - # If digests - if dashboard.digest != digest: - self.incr_stats("redirect", self.thumbnail.__name__) - return redirect( - url_for( - f"{self.__class__.__name__}.thumbnail", - pk=pk, - digest=dashboard.digest, - ) - ) - self.incr_stats("from_cache", self.thumbnail.__name__) - return Response( - FileWrapper(screenshot), mimetype="image/png", direct_passthrough=True - ) - @expose("//cache_dashboard_screenshot/", methods=("POST",)) @validate_feature_flags(["THUMBNAILS", "ENABLE_DASHBOARD_SCREENSHOT_ENDPOINTS"]) @protect() @@ -1172,7 +1069,6 @@ class DashboardRestApi(BaseSupersetModelRestApi): payload = CacheScreenshotSchema().load(request.json) except ValidationError as error: return self.response_400(message=error.messages) - dashboard = cast(Dashboard, self.datamodel.get(pk, self._base_filters)) if not dashboard: return self.response_404() @@ -1182,7 +1078,7 @@ class DashboardRestApi(BaseSupersetModelRestApi): ) # Don't shrink the image if thumb_size is not specified thumb_size = kwargs["rison"].get("thumb_size") or window_size - + force = kwargs["rison"].get("force", False) dashboard_state: DashboardPermalinkState = { "dataMask": payload.get("dataMask", {}), "activeTabs": payload.get("activeTabs", []), @@ -1197,13 +1093,29 @@ class DashboardRestApi(BaseSupersetModelRestApi): dashboard_url = get_url_path("Superset.dashboard_permalink", key=permalink_key) screenshot_obj = DashboardScreenshot(dashboard_url, dashboard.digest) - cache_key = screenshot_obj.cache_key(window_size, thumb_size, dashboard_state) + cache_key = screenshot_obj.get_cache_key( + window_size, thumb_size, dashboard_state + ) image_url = get_url_path( "DashboardRestApi.screenshot", pk=dashboard.id, digest=cache_key ) + cache_payload = ( + screenshot_obj.get_from_cache_key(cache_key) or ScreenshotCachePayload() + ) - def trigger_celery() -> WerkzeugResponse: + def build_response(status_code: int) -> WerkzeugResponse: + return self.response( + status_code, + cache_key=cache_key, + dashboard_url=dashboard_url, + image_url=image_url, + task_updated_at=cache_payload.get_timestamp(), + task_status=cache_payload.get_status(), + ) + + if cache_payload.should_trigger_task(force): logger.info("Triggering screenshot ASYNC") + screenshot_obj.cache.set(cache_key, ScreenshotCachePayload()) cache_dashboard_screenshot.delay( username=get_current_user(), guest_token=( @@ -1213,19 +1125,12 @@ class DashboardRestApi(BaseSupersetModelRestApi): ), dashboard_id=dashboard.id, dashboard_url=dashboard_url, - cache_key=cache_key, - force=False, thumb_size=thumb_size, window_size=window_size, + force=force, ) - return self.response( - 202, - cache_key=cache_key, - dashboard_url=dashboard_url, - image_url=image_url, - ) - - return trigger_celery() + return build_response(202) + return build_response(200) @expose("//screenshot//", methods=("GET",)) @validate_feature_flags(["THUMBNAILS", "ENABLE_DASHBOARD_SCREENSHOT_ENDPOINTS"]) @@ -1282,9 +1187,12 @@ class DashboardRestApi(BaseSupersetModelRestApi): # fetch the dashboard screenshot using the current user and cache if set - if img := DashboardScreenshot.get_from_cache_key(thumbnail_cache, digest): + if cache_payload := DashboardScreenshot.get_from_cache_key(digest): + image = cache_payload.get_image() + if not image: + return self.response_404() if download_format == "pdf": - pdf_img = img.getvalue() + pdf_img = image.getvalue() # Convert the screenshot to PDF pdf_data = build_pdf_from_screenshots([pdf_img]) @@ -1296,13 +1204,120 @@ class DashboardRestApi(BaseSupersetModelRestApi): ) if download_format == "png": return Response( - FileWrapper(img), + FileWrapper(image), mimetype="image/png", direct_passthrough=True, ) - return self.response_404() + @expose("//thumbnail//", methods=("GET",)) + @validate_feature_flags(["THUMBNAILS"]) + @protect() + @safe + @rison(thumbnail_query_schema) + @event_logger.log_this_with_context( + action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.thumbnail", + log_to_statsd=False, + ) + def thumbnail(self, pk: int, digest: str, **kwargs: Any) -> WerkzeugResponse: + """Compute async or get already computed dashboard thumbnail from cache. + --- + get: + summary: Get dashboard's thumbnail + description: >- + Computes async or get already computed dashboard thumbnail from cache. + parameters: + - in: path + schema: + type: integer + name: pk + - in: path + name: digest + description: A hex digest that makes this dashboard unique + schema: + type: string + responses: + 200: + description: Dashboard thumbnail image + content: + image/*: + schema: + type: string + format: binary + 202: + description: Thumbnail does not exist on cache, fired async to compute + content: + application/json: + schema: + type: object + properties: + message: + type: string + 302: + description: Redirects to the current digest + 401: + $ref: '#/components/responses/401' + 404: + $ref: '#/components/responses/404' + 422: + $ref: '#/components/responses/422' + 500: + $ref: '#/components/responses/500' + """ + dashboard = cast(Dashboard, self.datamodel.get(pk, self._base_filters)) + if not dashboard: + return self.response_404() + + current_user = get_current_user() + dashboard_url = get_url_path( + "Superset.dashboard", dashboard_id_or_slug=dashboard.id + ) + if dashboard.digest != digest: + self.incr_stats("redirect", self.thumbnail.__name__) + return redirect( + url_for( + f"{self.__class__.__name__}.thumbnail", + pk=pk, + digest=dashboard.digest, + ) + ) + screenshot_obj = DashboardScreenshot(dashboard_url, dashboard.digest) + cache_key = screenshot_obj.get_cache_key() + cache_payload = ( + screenshot_obj.get_from_cache_key(cache_key) or ScreenshotCachePayload() + ) + image_url = get_url_path( + "DashboardRestApi.thumbnail", pk=dashboard.id, digest=cache_key + ) + + if cache_payload.should_trigger_task(): + self.incr_stats("async", self.thumbnail.__name__) + logger.info( + "Triggering thumbnail compute (dashboard id: %s) ASYNC", + str(dashboard.id), + ) + screenshot_obj.cache.set(cache_key, ScreenshotCachePayload()) + cache_dashboard_thumbnail.delay( + current_user=current_user, + dashboard_id=dashboard.id, + force=False, + ) + return self.response( + 202, + cache_key=cache_key, + dashboard_url=dashboard_url, + image_url=image_url, + task_updated_at=cache_payload.get_timestamp(), + task_status=cache_payload.get_status(), + ) + + self.incr_stats("from_cache", self.thumbnail.__name__) + return Response( + FileWrapper(cache_payload.get_image()), + mimetype="image/png", + direct_passthrough=True, + ) + @expose("/favorite_status/", methods=("GET",)) @protect() @safe diff --git a/superset/dashboards/schemas.py b/superset/dashboards/schemas.py index 5b18e856c..d0b6230dc 100644 --- a/superset/dashboards/schemas.py +++ b/superset/dashboards/schemas.py @@ -507,6 +507,12 @@ class DashboardCacheScreenshotResponseSchema(Schema): image_url = fields.String( metadata={"description": "The url to fetch the screenshot"} ) + task_status = fields.String( + metadata={"description": "The status of the async screenshot"} + ) + task_updated_at = fields.String( + metadata={"description": "The timestamp of the last change in status"} + ) class CacheScreenshotSchema(Schema): diff --git a/superset/models/slice.py b/superset/models/slice.py index 8795d186e..2469db90a 100644 --- a/superset/models/slice.py +++ b/superset/models/slice.py @@ -380,9 +380,7 @@ def event_after_chart_changed( _mapper: Mapper, _connection: Connection, target: Slice ) -> None: cache_chart_thumbnail.delay( - current_user=get_current_user(), - chart_id=target.id, - force=True, + current_user=get_current_user(), chart_id=target.id, force=True ) diff --git a/superset/tasks/thumbnails.py b/superset/tasks/thumbnails.py index 3b0b47dbb..8f8597250 100644 --- a/superset/tasks/thumbnails.py +++ b/superset/tasks/thumbnails.py @@ -38,7 +38,7 @@ logger = logging.getLogger(__name__) def cache_chart_thumbnail( current_user: Optional[str], chart_id: int, - force: bool = False, + force: bool, window_size: Optional[WindowSize] = None, thumb_size: Optional[WindowSize] = None, ) -> None: @@ -64,10 +64,9 @@ def cache_chart_thumbnail( screenshot = ChartScreenshot(url, chart.digest) screenshot.compute_and_cache( user=user, - cache=thumbnail_cache, - force=force, window_size=window_size, thumb_size=thumb_size, + force=force, ) return None @@ -76,7 +75,7 @@ def cache_chart_thumbnail( def cache_dashboard_thumbnail( current_user: Optional[str], dashboard_id: int, - force: bool = False, + force: bool, thumb_size: Optional[WindowSize] = None, window_size: Optional[WindowSize] = None, ) -> None: @@ -101,10 +100,9 @@ def cache_dashboard_thumbnail( screenshot = DashboardScreenshot(url, dashboard.digest) screenshot.compute_and_cache( user=user, - cache=thumbnail_cache, - force=force, window_size=window_size, thumb_size=thumb_size, + force=force, ) @@ -113,7 +111,7 @@ def cache_dashboard_screenshot( # pylint: disable=too-many-arguments username: str, dashboard_id: int, dashboard_url: str, - force: bool = True, + force: bool, cache_key: Optional[str] = None, guest_token: Optional[GuestToken] = None, thumb_size: Optional[WindowSize] = None, @@ -145,9 +143,8 @@ def cache_dashboard_screenshot( # pylint: disable=too-many-arguments screenshot = DashboardScreenshot(dashboard_url, dashboard.digest) screenshot.compute_and_cache( user=current_user, - cache=thumbnail_cache, - force=force, window_size=window_size, thumb_size=thumb_size, cache_key=cache_key, + force=force, ) diff --git a/superset/utils/screenshots.py b/superset/utils/screenshots.py index 1557bc283..86f5a94ce 100644 --- a/superset/utils/screenshots.py +++ b/superset/utils/screenshots.py @@ -17,12 +17,14 @@ from __future__ import annotations import logging +from datetime import datetime +from enum import Enum from io import BytesIO from typing import TYPE_CHECKING from flask import current_app -from superset import feature_flag_manager +from superset import app, feature_flag_manager, thumbnail_cache from superset.dashboards.permalink.types import DashboardPermalinkState from superset.extensions import event_logger from superset.utils.hashing import md5_sha_from_dict @@ -54,6 +56,70 @@ if TYPE_CHECKING: from flask_caching import Cache +class StatusValues(Enum): + PENDING = "Pending" + COMPUTING = "Computing" + UPDATED = "Updated" + ERROR = "Error" + + +class ScreenshotCachePayload: + def __init__(self, image: bytes | None = None): + self._image = image + self._timestamp = datetime.now().isoformat() + self.status = StatusValues.PENDING + if image: + self.status = StatusValues.UPDATED + + def update_timestamp(self) -> None: + self._timestamp = datetime.now().isoformat() + + def pending(self) -> None: + self.update_timestamp() + self._image = None + self.status = StatusValues.PENDING + + def computing(self) -> None: + self.update_timestamp() + self._image = None + self.status = StatusValues.COMPUTING + + def update(self, image: bytes) -> None: + self.update_timestamp() + self.status = StatusValues.UPDATED + self._image = image + + def error( + self, + ) -> None: + self.update_timestamp() + self.status = StatusValues.ERROR + + def get_image(self) -> BytesIO | None: + if not self._image: + return None + return BytesIO(self._image) + + def get_timestamp(self) -> str: + return self._timestamp + + def get_status(self) -> str: + return self.status.value + + def is_error_cache_ttl_expired(self) -> bool: + error_cache_ttl = app.config["THUMBNAIL_ERROR_CACHE_TTL"] + return ( + datetime.now() - datetime.fromisoformat(self.get_timestamp()) + ).total_seconds() > error_cache_ttl + + def should_trigger_task(self, force: bool = False) -> bool: + return ( + force + or self.status == StatusValues.PENDING + or (self.status == StatusValues.ERROR and self.is_error_cache_ttl_expired()) + ) + + class BaseScreenshot: driver_type = current_app.config["WEBDRIVER_TYPE"] url: str @@ -63,6 +129,7 @@ class BaseScreenshot: element: str = "" window_size: WindowSize = DEFAULT_SCREENSHOT_WINDOW_SIZE thumb_size: WindowSize = DEFAULT_SCREENSHOT_THUMBNAIL_SIZE + cache: Cache = thumbnail_cache def __init__(self, url: str, digest: str | None): self.digest = digest @@ -75,7 +142,14 @@ class BaseScreenshot: return WebDriverPlaywright(self.driver_type, window_size) return WebDriverSelenium(self.driver_type, window_size) - def cache_key( + def get_screenshot( + self, user: User, window_size: WindowSize | None = None + ) -> bytes | None: + driver = self.driver(window_size) + self.screenshot = driver.get_screenshot(self.url, self.element, user) + return self.screenshot + + def get_cache_key( self, window_size: bool | WindowSize | None = None, thumb_size: bool | WindowSize | None = None, @@ -91,69 +165,35 @@ class BaseScreenshot: } return md5_sha_from_dict(args) - def get_screenshot( - self, user: User, window_size: WindowSize | None = None - ) -> bytes | None: - driver = self.driver(window_size) - with event_logger.log_context("screenshot", screenshot_url=self.url): - self.screenshot = driver.get_screenshot(self.url, self.element, user) - return self.screenshot - - def get( - self, - user: User = None, - cache: Cache = None, - thumb_size: WindowSize | None = None, - ) -> BytesIO | None: - """ - Get thumbnail screenshot has BytesIO from cache or fetch - - :param user: None to use current user or User Model to login and fetch - :param cache: The cache to use - :param thumb_size: Override thumbnail site - """ - payload: bytes | None = None - cache_key = self.cache_key(self.window_size, thumb_size) - if cache: - payload = cache.get(cache_key) - if not payload: - payload = self.compute_and_cache( - user=user, thumb_size=thumb_size, cache=cache - ) - else: - logger.info("Loaded thumbnail from cache: %s", cache_key) - if payload: - return BytesIO(payload) - return None - def get_from_cache( self, - cache: Cache, window_size: WindowSize | None = None, thumb_size: WindowSize | None = None, - ) -> BytesIO | None: - cache_key = self.cache_key(window_size, thumb_size) - return self.get_from_cache_key(cache, cache_key) + ) -> ScreenshotCachePayload | None: + cache_key = self.get_cache_key(window_size, thumb_size) + return self.get_from_cache_key(cache_key) - @staticmethod - def get_from_cache_key(cache: Cache, cache_key: str) -> BytesIO | None: + @classmethod + def get_from_cache_key(cls, cache_key: str) -> ScreenshotCachePayload | None: logger.info("Attempting to get from cache: %s", cache_key) - if payload := cache.get(cache_key): - return BytesIO(payload) + if payload := cls.cache.get(cache_key): + # for backwards compatability, byte objects should be converted + if not isinstance(payload, ScreenshotCachePayload): + payload = ScreenshotCachePayload(payload) + return payload logger.info("Failed at getting from cache: %s", cache_key) return None def compute_and_cache( # pylint: disable=too-many-arguments self, + force: bool, user: User = None, window_size: WindowSize | None = None, thumb_size: WindowSize | None = None, - cache: Cache = None, - force: bool = True, cache_key: str | None = None, - ) -> bytes | None: + ) -> None: """ - Fetches the screenshot, computes the thumbnail and caches the result + Computes the thumbnail and caches the result :param user: If no user is given will use the current context :param cache: The cache to keep the thumbnail payload @@ -162,40 +202,46 @@ class BaseScreenshot: :param force: Will force the computation even if it's already cached :return: Image payload """ - cache_key = cache_key or self.cache_key(window_size, thumb_size) + cache_key = cache_key or self.get_cache_key(window_size, thumb_size) + cache_payload = self.get_from_cache_key(cache_key) or ScreenshotCachePayload() + if ( + cache_payload.status in [StatusValues.COMPUTING, StatusValues.UPDATED] + and not force + ): + logger.info( + "Skipping compute - already processed for thumbnail: %s", cache_key + ) + return + window_size = window_size or self.window_size thumb_size = thumb_size or self.thumb_size - if not force and cache and cache.get(cache_key): - logger.info("Thumb already cached, skipping...") - return None logger.info("Processing url for thumbnail: %s", cache_key) - - payload = None - + cache_payload.computing() + self.cache.set(cache_key, cache_payload) + image = None # Assuming all sorts of things can go wrong with Selenium try: - with event_logger.log_context( - f"screenshot.compute.{self.thumbnail_type}", force=force - ): - payload = self.get_screenshot(user=user, window_size=window_size) + logger.info("trying to generate screenshot") + with event_logger.log_context(f"screenshot.compute.{self.thumbnail_type}"): + image = self.get_screenshot(user=user, window_size=window_size) except Exception as ex: # pylint: disable=broad-except logger.warning("Failed at generating thumbnail %s", ex, exc_info=True) - - if payload and window_size != thumb_size: + cache_payload.error() + if image and window_size != thumb_size: try: - payload = self.resize_image(payload, thumb_size=thumb_size) + image = self.resize_image(image, thumb_size=thumb_size) except Exception as ex: # pylint: disable=broad-except logger.warning("Failed at resizing thumbnail %s", ex, exc_info=True) - payload = None + cache_payload.error() + image = None - if payload: + if image: logger.info("Caching thumbnail: %s", cache_key) - with event_logger.log_context( - f"screenshot.cache.{self.thumbnail_type}", force=force - ): - cache.set(cache_key, payload) - logger.info("Done caching thumbnail") - return payload + with event_logger.log_context(f"screenshot.cache.{self.thumbnail_type}"): + cache_payload.update(image) + self.cache.set(cache_key, cache_payload) + logger.info("Updated thumbnail cache; Status: %s", cache_payload.get_status()) + return @classmethod def resize_image( @@ -265,7 +311,7 @@ class DashboardScreenshot(BaseScreenshot): self.window_size = window_size or DEFAULT_DASHBOARD_WINDOW_SIZE self.thumb_size = thumb_size or DEFAULT_DASHBOARD_THUMBNAIL_SIZE - def cache_key( + def get_cache_key( self, window_size: bool | WindowSize | None = None, thumb_size: bool | WindowSize | None = None, diff --git a/superset/utils/webdriver.py b/superset/utils/webdriver.py index c8e46581e..f9a7899ff 100644 --- a/superset/utils/webdriver.py +++ b/superset/utils/webdriver.py @@ -380,7 +380,7 @@ class WebDriverSelenium(WebDriverProxy): return error_messages - def get_screenshot(self, url: str, element_name: str, user: User) -> bytes | None: + def get_screenshot(self, url: str, element_name: str, user: User) -> bytes | None: # noqa: C901 driver = self.auth(user) driver.set_window_size(*self._window) driver.get(url) @@ -411,6 +411,7 @@ class WebDriverSelenium(WebDriverProxy): ) ) except TimeoutException: + logger.info("Timeout Exception caught") # Fallback to allow a screenshot of an empty dashboard try: WebDriverWait(driver, 0).until( @@ -461,18 +462,23 @@ class WebDriverSelenium(WebDriverProxy): ) img = element.screenshot_as_png + except Exception as ex: + logger.warning("exception in webdriver", exc_info=ex) + raise except TimeoutException: # raise again for the finally block, but handled above - pass + raise except StaleElementReferenceException: logger.exception( "Selenium got a stale element while requesting url %s", url, ) + raise except WebDriverException: logger.exception( "Encountered an unexpected error when requesting url %s", url ) + raise finally: self.destroy(driver, current_app.config["SCREENSHOT_SELENIUM_RETRIES"]) return img diff --git a/tests/integration_tests/cli_tests.py b/tests/integration_tests/cli_tests.py index 048612a08..a00cf2981 100644 --- a/tests/integration_tests/cli_tests.py +++ b/tests/integration_tests/cli_tests.py @@ -319,9 +319,5 @@ def test_compute_thumbnails(thumbnail_mock, app_context, fs): ["-d", "-i", dashboard.id], ) - thumbnail_mock.assert_called_with( - None, - dashboard.id, - force=False, - ) + thumbnail_mock.assert_called_with(None, dashboard.id, force=False) assert response.exit_code == 0 diff --git a/tests/integration_tests/dashboards/api_tests.py b/tests/integration_tests/dashboards/api_tests.py index 2bab17d05..430cdf8fe 100644 --- a/tests/integration_tests/dashboards/api_tests.py +++ b/tests/integration_tests/dashboards/api_tests.py @@ -37,6 +37,7 @@ from superset.reports.models import ReportSchedule, ReportScheduleType from superset.models.slice import Slice from superset.tags.models import Tag, TaggedObject, TagType, ObjectType from superset.utils.core import backend, override_user +from superset.utils.screenshots import ScreenshotCachePayload from superset.utils import json from tests.integration_tests.base_api_tests import ApiOwnersTestCaseMixin @@ -3069,13 +3070,15 @@ class TestDashboardApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCas @pytest.mark.usefixtures("create_dashboard_with_tag") @patch("superset.dashboards.api.cache_dashboard_screenshot") @patch("superset.dashboards.api.DashboardScreenshot.get_from_cache_key") - def test_screenshot_success_png(self, mock_get_cache, mock_cache_task): + def test_screenshot_success_png(self, mock_get_from_cache_key, mock_cache_task): """ Validate screenshot returns png """ self.login(ADMIN_USERNAME) mock_cache_task.return_value = None - mock_get_cache.return_value = BytesIO(b"fake image data") + mock_get_from_cache_key.return_value = ScreenshotCachePayload( + b"fake image data" + ) dashboard = ( db.session.query(Dashboard) @@ -3083,7 +3086,7 @@ class TestDashboardApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCas .first() ) cache_resp = self._cache_screenshot(dashboard.id) - assert cache_resp.status_code == 202 + assert cache_resp.status_code == 200 cache_key = json.loads(cache_resp.data.decode("utf-8"))["cache_key"] response = self._get_screenshot(dashboard.id, cache_key, "png") @@ -3091,20 +3094,29 @@ class TestDashboardApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCas assert response.mimetype == "image/png" assert response.data == b"fake image data" + mock_get_from_cache_key.return_value = ScreenshotCachePayload() + cache_resp = self._cache_screenshot(dashboard.id) + assert cache_resp.status_code == 202 + @with_feature_flags(THUMBNAILS=True, ENABLE_DASHBOARD_SCREENSHOT_ENDPOINTS=True) @pytest.mark.usefixtures("create_dashboard_with_tag") @patch("superset.dashboards.api.cache_dashboard_screenshot") @patch("superset.dashboards.api.build_pdf_from_screenshots") @patch("superset.dashboards.api.DashboardScreenshot.get_from_cache_key") def test_screenshot_success_pdf( - self, mock_get_from_cache, mock_build_pdf, mock_cache_task + self, + mock_get_from_cache_key, + mock_build_pdf, + mock_cache_task, ): """ Validate screenshot can return pdf. """ self.login(ADMIN_USERNAME) mock_cache_task.return_value = None - mock_get_from_cache.return_value = BytesIO(b"fake image data") + mock_get_from_cache_key.return_value = ScreenshotCachePayload( + b"fake image data" + ) mock_build_pdf.return_value = b"fake pdf data" dashboard = ( @@ -3113,7 +3125,7 @@ class TestDashboardApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCas .first() ) cache_resp = self._cache_screenshot(dashboard.id) - assert cache_resp.status_code == 202 + assert cache_resp.status_code == 200 cache_key = json.loads(cache_resp.data.decode("utf-8"))["cache_key"] response = self._get_screenshot(dashboard.id, cache_key, "pdf") @@ -3121,6 +3133,10 @@ class TestDashboardApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCas assert response.mimetype == "application/pdf" assert response.data == b"fake pdf data" + mock_get_from_cache_key.return_value = ScreenshotCachePayload() + cache_resp = self._cache_screenshot(dashboard.id) + assert cache_resp.status_code == 202 + @with_feature_flags(THUMBNAILS=True, ENABLE_DASHBOARD_SCREENSHOT_ENDPOINTS=True) @pytest.mark.usefixtures("create_dashboard_with_tag") @patch("superset.dashboards.api.cache_dashboard_screenshot") @@ -3153,10 +3169,12 @@ class TestDashboardApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCas @pytest.mark.usefixtures("create_dashboard_with_tag") @patch("superset.dashboards.api.cache_dashboard_screenshot") @patch("superset.dashboards.api.DashboardScreenshot.get_from_cache_key") - def test_screenshot_invalid_download_format(self, mock_get_cache, mock_cache_task): + def test_screenshot_invalid_download_format( + self, mock_get_from_cache_key, mock_cache_task + ): self.login(ADMIN_USERNAME) mock_cache_task.return_value = None - mock_get_cache.return_value = BytesIO(b"fake png data") + mock_get_from_cache_key.return_value = ScreenshotCachePayload(b"fake png data") dashboard = ( db.session.query(Dashboard) @@ -3165,9 +3183,13 @@ class TestDashboardApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCas ) cache_resp = self._cache_screenshot(dashboard.id) - assert cache_resp.status_code == 202 + assert cache_resp.status_code == 200 cache_key = json.loads(cache_resp.data.decode("utf-8"))["cache_key"] + mock_get_from_cache_key.return_value = ScreenshotCachePayload() + cache_resp = self._cache_screenshot(dashboard.id) + assert cache_resp.status_code == 202 + response = self._get_screenshot(dashboard.id, cache_key, "invalid") assert response.status_code == 404 diff --git a/tests/integration_tests/thumbnails_tests.py b/tests/integration_tests/thumbnails_tests.py index e808858fb..a3d4e4b3f 100644 --- a/tests/integration_tests/thumbnails_tests.py +++ b/tests/integration_tests/thumbnails_tests.py @@ -18,7 +18,6 @@ # from superset.models.dashboard import Dashboard import urllib.request -from io import BytesIO from unittest import skipUnless from unittest.mock import ANY, call, MagicMock, patch @@ -32,7 +31,11 @@ from superset.models.dashboard import Dashboard from superset.models.slice import Slice from superset.tasks.types import ExecutorType, FixedExecutor from superset.utils import json -from superset.utils.screenshots import ChartScreenshot, DashboardScreenshot +from superset.utils.screenshots import ( + ChartScreenshot, + DashboardScreenshot, + ScreenshotCachePayload, +) from superset.utils.urls import get_url_path from superset.utils.webdriver import WebDriverSelenium from tests.integration_tests.base_tests import SupersetTestCase @@ -287,14 +290,14 @@ class TestThumbnails(SupersetTestCase): @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature") - def test_get_async_dashboard_not_allowed(self): + def test_get_async_dashboard_created(self): """ Thumbnails: Simple get async dashboard not allowed """ self.login(ADMIN_USERNAME) _, thumbnail_url = self._get_id_and_thumbnail_url(DASHBOARD_URL) rv = self.client.get(thumbnail_url) - assert rv.status_code == 404 + assert rv.status_code == 202 @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @with_feature_flags(THUMBNAILS=True) @@ -370,7 +373,9 @@ class TestThumbnails(SupersetTestCase): Thumbnails: Simple get chart with wrong digest """ with patch.object( - ChartScreenshot, "get_from_cache", return_value=BytesIO(self.mock_image) + ChartScreenshot, + "get_from_cache", + return_value=ScreenshotCachePayload(self.mock_image), ): self.login(ADMIN_USERNAME) id_, thumbnail_url = self._get_id_and_thumbnail_url(CHART_URL) @@ -385,7 +390,9 @@ class TestThumbnails(SupersetTestCase): Thumbnails: Simple get cached dashboard screenshot """ with patch.object( - DashboardScreenshot, "get_from_cache", return_value=BytesIO(self.mock_image) + DashboardScreenshot, + "get_from_cache_key", + return_value=ScreenshotCachePayload(self.mock_image), ): self.login(ADMIN_USERNAME) _, thumbnail_url = self._get_id_and_thumbnail_url(DASHBOARD_URL) @@ -400,7 +407,9 @@ class TestThumbnails(SupersetTestCase): Thumbnails: Simple get cached chart screenshot """ with patch.object( - ChartScreenshot, "get_from_cache", return_value=BytesIO(self.mock_image) + ChartScreenshot, + "get_from_cache_key", + return_value=ScreenshotCachePayload(self.mock_image), ): self.login(ADMIN_USERNAME) id_, thumbnail_url = self._get_id_and_thumbnail_url(CHART_URL) @@ -415,7 +424,9 @@ class TestThumbnails(SupersetTestCase): Thumbnails: Simple get dashboard with wrong digest """ with patch.object( - DashboardScreenshot, "get_from_cache", return_value=BytesIO(self.mock_image) + DashboardScreenshot, + "get_from_cache", + return_value=ScreenshotCachePayload(self.mock_image), ): self.login(ADMIN_USERNAME) id_, thumbnail_url = self._get_id_and_thumbnail_url(DASHBOARD_URL) diff --git a/tests/unit_tests/utils/screenshot_test.py b/tests/unit_tests/utils/screenshot_test.py new file mode 100644 index 000000000..5d29d829a --- /dev/null +++ b/tests/unit_tests/utils/screenshot_test.py @@ -0,0 +1,194 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: disable=import-outside-toplevel, unused-argument + +from unittest.mock import MagicMock + +import pytest +from pytest_mock import MockerFixture + +from superset.utils.hashing import md5_sha_from_dict +from superset.utils.screenshots import ( + BaseScreenshot, + ScreenshotCachePayload, + StatusValues, +) + +BASE_SCREENSHOT_PATH = "superset.utils.screenshots.BaseScreenshot" + + +class MockCache: + """A class to manage screenshot cache.""" + + def __init__(self): + self._cache = None # Store the cached value + + def set(self, _key, value): + """Set the cache with a new value.""" + self._cache = value + + def get(self, _key): + """Get the cached value.""" + return self._cache + + +@pytest.fixture +def mock_user(): + """Fixture to create a mock user.""" + mock_user = MagicMock() + mock_user.id = 1 + return mock_user + + +@pytest.fixture +def screenshot_obj(): + """Fixture to create a BaseScreenshot object.""" + url = "http://example.com" + digest = "sample_digest" + return BaseScreenshot(url, digest) + + +def test_get_screenshot(mocker: MockerFixture, screenshot_obj): + """Get screenshot should return a Bytes object""" + fake_bytes = b"fake_screenshot_data" + driver = mocker.patch(BASE_SCREENSHOT_PATH + ".driver") + driver.return_value.get_screenshot.return_value = fake_bytes + screenshot_data = screenshot_obj.get_screenshot(mock_user) + assert screenshot_data == fake_bytes + + +def test_get_cache_key(screenshot_obj): + """Test get_cache_key method""" + expected_cache_key = md5_sha_from_dict( + { + "thumbnail_type": "", + "digest": screenshot_obj.digest, + "type": "thumb", + "window_size": screenshot_obj.window_size, + "thumb_size": screenshot_obj.thumb_size, + } + ) + cache_key = screenshot_obj.get_cache_key() + assert cache_key == expected_cache_key + + +def test_get_from_cache_key(mocker: MockerFixture, screenshot_obj): + """get_from_cache_key should always return a ScreenshotCachePayload Object""" + # backwards compatability test for retrieving plain bytes + fake_bytes = b"fake_screenshot_data" + BaseScreenshot.cache = MockCache() + BaseScreenshot.cache.set("key", fake_bytes) + cache_payload = screenshot_obj.get_from_cache_key("key") + assert isinstance(cache_payload, ScreenshotCachePayload) + assert cache_payload._image == fake_bytes # pylint: disable=protected-access + + +class TestComputeAndCache: + def _setup_compute_and_cache(self, mocker: MockerFixture, screenshot_obj): + """Helper method to handle the common setup for the tests.""" + # Patch the methods + get_from_cache_key = mocker.patch( + BASE_SCREENSHOT_PATH + ".get_from_cache_key", return_value=None + ) + get_screenshot = mocker.patch( + BASE_SCREENSHOT_PATH + ".get_screenshot", return_value=b"new_image_data" + ) + resize_image = mocker.patch( + BASE_SCREENSHOT_PATH + ".resize_image", return_value=b"resized_image_data" + ) + BaseScreenshot.cache = MockCache() + return { + "get_from_cache_key": get_from_cache_key, + "get_screenshot": get_screenshot, + "resize_image": resize_image, + } + + def test_happy_path(self, mocker: MockerFixture, screenshot_obj): + self._setup_compute_and_cache(mocker, screenshot_obj) + screenshot_obj.compute_and_cache(force=False) + cache_payload: ScreenshotCachePayload = screenshot_obj.cache.get("key") + assert cache_payload.status == StatusValues.UPDATED + + def test_screenshot_error(self, mocker: MockerFixture, screenshot_obj): + mocks = self._setup_compute_and_cache(mocker, screenshot_obj) + get_screenshot: MagicMock = mocks.get("get_screenshot") + get_screenshot.side_effect = Exception + screenshot_obj.compute_and_cache(force=False) + cache_payload: ScreenshotCachePayload = screenshot_obj.cache.get("key") + assert cache_payload.status == StatusValues.ERROR + + def test_resize_error(self, mocker: MockerFixture, screenshot_obj): + mocks = self._setup_compute_and_cache(mocker, screenshot_obj) + resize_image: MagicMock = mocks.get("resize_image") + resize_image.side_effect = Exception + screenshot_obj.compute_and_cache(force=False) + cache_payload: ScreenshotCachePayload = screenshot_obj.cache.get("key") + assert cache_payload.status == StatusValues.ERROR + + def test_skips_if_computing(self, mocker: MockerFixture, screenshot_obj): + mocks = self._setup_compute_and_cache(mocker, screenshot_obj) + cached_value = ScreenshotCachePayload() + cached_value.computing() + get_from_cache_key = mocks.get("get_from_cache_key") + get_from_cache_key.return_value = cached_value + + # Ensure that it skips when thumbnail status is computing + screenshot_obj.compute_and_cache(force=False) + get_screenshot = mocks.get("get_screenshot") + get_screenshot.assert_not_called() + + # Ensure that it processes when force = True + screenshot_obj.compute_and_cache(force=True) + get_screenshot.assert_called_once() + cache_payload: ScreenshotCachePayload = screenshot_obj.cache.get("key") + assert cache_payload.status == StatusValues.UPDATED + + def test_skips_if_updated(self, mocker: MockerFixture, screenshot_obj): + mocks = self._setup_compute_and_cache(mocker, screenshot_obj) + cached_value = ScreenshotCachePayload(image=b"initial_value") + get_from_cache_key = mocks.get("get_from_cache_key") + get_from_cache_key.return_value = cached_value + + # Ensure that it skips when thumbnail status is updated + window_size = thumb_size = (10, 10) + screenshot_obj.compute_and_cache( + force=False, window_size=window_size, thumb_size=thumb_size + ) + get_screenshot = mocks.get("get_screenshot") + get_screenshot.assert_not_called() + + # Ensure that it processes when force = True + screenshot_obj.compute_and_cache( + force=True, window_size=window_size, thumb_size=thumb_size + ) + get_screenshot.assert_called_once() + cache_payload: ScreenshotCachePayload = screenshot_obj.cache.get("key") + assert cache_payload._image != b"initial_value" + + def test_resize(self, mocker: MockerFixture, screenshot_obj): + mocks = self._setup_compute_and_cache(mocker, screenshot_obj) + window_size = thumb_size = (10, 10) + resize_image: MagicMock = mocks.get("resize_image") + screenshot_obj.compute_and_cache( + force=False, window_size=window_size, thumb_size=thumb_size + ) + resize_image.assert_not_called() + screenshot_obj.compute_and_cache( + force=False, window_size=(1, 1), thumb_size=thumb_size + ) + resize_image.assert_called_once()