diff --git a/superset/dashboards/permalink/commands/create.py b/superset/dashboards/permalink/commands/create.py index b8cbdbd3a..75cb4f574 100644 --- a/superset/dashboards/permalink/commands/create.py +++ b/superset/dashboards/permalink/commands/create.py @@ -22,13 +22,19 @@ from superset.dashboards.dao import DashboardDAO from superset.dashboards.permalink.commands.base import BaseDashboardPermalinkCommand from superset.dashboards.permalink.exceptions import DashboardPermalinkCreateFailedError from superset.dashboards.permalink.types import DashboardPermalinkState -from superset.key_value.commands.create import CreateKeyValueCommand -from superset.key_value.utils import encode_permalink_key +from superset.key_value.commands.upsert import UpsertKeyValueCommand +from superset.key_value.utils import encode_permalink_key, get_deterministic_uuid +from superset.utils.core import get_user_id logger = logging.getLogger(__name__) class CreateDashboardPermalinkCommand(BaseDashboardPermalinkCommand): + """ + Get or create a permalink key for the given dashboard in certain state. + Will reuse the key for the same user and dashboard state. + """ + def __init__( self, dashboard_id: str, @@ -45,12 +51,13 @@ class CreateDashboardPermalinkCommand(BaseDashboardPermalinkCommand): "dashboardId": self.dashboard_id, "state": self.state, } - key = CreateKeyValueCommand( + user_id = get_user_id() + key = UpsertKeyValueCommand( resource=self.resource, + key=get_deterministic_uuid(self.salt, (user_id, value)), value=value, ).run() - if key.id is None: - raise DashboardPermalinkCreateFailedError("Unexpected missing key id") + assert key.id # for type checks return encode_permalink_key(key=key.id, salt=self.salt) except SQLAlchemyError as ex: logger.exception("Error running create command") diff --git a/superset/key_value/commands/upsert.py b/superset/key_value/commands/upsert.py index 4bb64aa24..80b025559 100644 --- a/superset/key_value/commands/upsert.py +++ b/superset/key_value/commands/upsert.py @@ -63,7 +63,7 @@ class UpsertKeyValueCommand(BaseCommand): self.value = value self.expires_on = expires_on - def run(self) -> Optional[Key]: + def run(self) -> Key: try: return self.upsert() except SQLAlchemyError as ex: @@ -74,7 +74,7 @@ class UpsertKeyValueCommand(BaseCommand): def validate(self) -> None: pass - def upsert(self) -> Optional[Key]: + def upsert(self) -> Key: filter_ = get_filter(self.resource, self.key) entry: KeyValueEntry = ( db.session.query(KeyValueEntry) diff --git a/superset/key_value/utils.py b/superset/key_value/utils.py index b2e8e729b..2468618a8 100644 --- a/superset/key_value/utils.py +++ b/superset/key_value/utils.py @@ -18,14 +18,15 @@ from __future__ import annotations from hashlib import md5 from secrets import token_urlsafe -from typing import Union -from uuid import UUID +from typing import Any, Union +from uuid import UUID, uuid3 import hashids from flask_babel import gettext as _ from superset.key_value.exceptions import KeyValueParseKeyError from superset.key_value.types import KeyValueFilter, KeyValueResource +from superset.utils.core import json_dumps_w_dates HASHIDS_MIN_LENGTH = 11 @@ -63,3 +64,9 @@ def get_uuid_namespace(seed: str) -> UUID: md5_obj = md5() md5_obj.update(seed.encode("utf-8")) return UUID(md5_obj.hexdigest()) + + +def get_deterministic_uuid(namespace: str, payload: Any) -> UUID: + """Get a deterministic UUID (uuid3) from a salt and a JSON-serializable payload.""" + payload_str = json_dumps_w_dates(payload, sort_keys=True) + return uuid3(get_uuid_namespace(namespace), payload_str) diff --git a/superset/utils/core.py b/superset/utils/core.py index 557d14c81..5ce52f9f4 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -608,8 +608,9 @@ def json_int_dttm_ser(obj: Any) -> float: return obj -def json_dumps_w_dates(payload: Dict[Any, Any]) -> str: - return json.dumps(payload, default=json_int_dttm_ser) +def json_dumps_w_dates(payload: Dict[Any, Any], sort_keys: bool = False) -> str: + """Dumps payload to JSON with Datetime objects properly converted""" + return json.dumps(payload, default=json_int_dttm_ser, sort_keys=sort_keys) def error_msg_from_exception(ex: Exception) -> str: diff --git a/tests/integration_tests/dashboards/permalink/api_tests.py b/tests/integration_tests/dashboards/permalink/api_tests.py index 12d758d5e..036b42857 100644 --- a/tests/integration_tests/dashboards/permalink/api_tests.py +++ b/tests/integration_tests/dashboards/permalink/api_tests.py @@ -71,11 +71,17 @@ def test_post(client, dashboard_id: int, permalink_salt: str) -> None: login(client, "admin") resp = client.post(f"api/v1/dashboard/{dashboard_id}/permalink", json=STATE) assert resp.status_code == 201 - data = json.loads(resp.data.decode("utf-8")) + data = resp.json key = data["key"] url = data["url"] assert key in url id_ = decode_permalink_id(key, permalink_salt) + + assert ( + data + == client.post(f"api/v1/dashboard/{dashboard_id}/permalink", json=STATE).json + ), "Should always return the same permalink key for the same payload" + db.session.query(KeyValueEntry).filter_by(id=id_).delete() db.session.commit() @@ -98,12 +104,12 @@ def test_post_invalid_schema(client, dashboard_id: int): def test_get(client, dashboard_id: int, permalink_salt: str): login(client, "admin") - resp = client.post(f"api/v1/dashboard/{dashboard_id}/permalink", json=STATE) - data = json.loads(resp.data.decode("utf-8")) - key = data["key"] + key = client.post(f"api/v1/dashboard/{dashboard_id}/permalink", json=STATE).json[ + "key" + ] resp = client.get(f"api/v1/dashboard/permalink/{key}") assert resp.status_code == 200 - result = json.loads(resp.data.decode("utf-8")) + result = resp.json assert result["dashboardId"] == str(dashboard_id) assert result["state"] == STATE id_ = decode_permalink_id(key, permalink_salt)