chore(key-value): convert command to dao (#29344)

This commit is contained in:
Ville Brofeldt 2024-07-01 20:22:11 +03:00 committed by GitHub
parent 0cf676b574
commit 7d6e933348
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
36 changed files with 867 additions and 1162 deletions

View File

@ -19,9 +19,10 @@ from functools import partial
from sqlalchemy.exc import SQLAlchemyError
from superset import db
from superset.commands.dashboard.permalink.base import BaseDashboardPermalinkCommand
from superset.commands.key_value.upsert import UpsertKeyValueCommand
from superset.daos.dashboard import DashboardDAO
from superset.daos.key_value import KeyValueDAO
from superset.dashboards.permalink.exceptions import DashboardPermalinkCreateFailedError
from superset.dashboards.permalink.types import DashboardPermalinkState
from superset.key_value.exceptions import (
@ -70,14 +71,15 @@ class CreateDashboardPermalinkCommand(BaseDashboardPermalinkCommand):
"state": self.state,
}
user_id = get_user_id()
key = UpsertKeyValueCommand(
entry = KeyValueDAO.upsert_entry(
resource=self.resource,
key=get_deterministic_uuid(self.salt, (user_id, value)),
value=value,
codec=self.codec,
).run()
assert key.id # for type checks
return encode_permalink_key(key=key.id, salt=self.salt)
)
db.session.flush()
assert entry.id # for type checks
return encode_permalink_key(key=entry.id, salt=self.salt)
def validate(self) -> None:
pass

View File

@ -21,8 +21,8 @@ from sqlalchemy.exc import SQLAlchemyError
from superset.commands.dashboard.exceptions import DashboardNotFoundError
from superset.commands.dashboard.permalink.base import BaseDashboardPermalinkCommand
from superset.commands.key_value.get import GetKeyValueCommand
from superset.daos.dashboard import DashboardDAO
from superset.daos.key_value import KeyValueDAO
from superset.dashboards.permalink.exceptions import DashboardPermalinkGetFailedError
from superset.dashboards.permalink.types import DashboardPermalinkValue
from superset.key_value.exceptions import (
@ -43,12 +43,7 @@ class GetDashboardPermalinkCommand(BaseDashboardPermalinkCommand):
self.validate()
try:
key = decode_permalink_id(self.key, salt=self.salt)
command = GetKeyValueCommand(
resource=self.resource,
key=key,
codec=self.codec,
)
value: Optional[DashboardPermalinkValue] = command.run()
value = KeyValueDAO.get_value(self.resource, key, self.codec)
if value:
DashboardDAO.get_by_id_or_slug(value["dashboardId"])
return value

View File

@ -14,3 +14,28 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
import uuid
from typing import Any
from flask import current_app
from superset.commands.base import BaseCommand
from superset.distributed_lock.utils import get_key
from superset.key_value.types import JsonKeyValueCodec, KeyValueResource
logger = logging.getLogger(__name__)
stats_logger = current_app.config["STATS_LOGGER"]
class BaseDistributedLockCommand(BaseCommand):
key: uuid.UUID
codec = JsonKeyValueCodec()
resource = KeyValueResource.LOCK
def __init__(self, namespace: str, params: dict[str, Any] | None = None):
self.key = get_key(namespace, **(params or {}))
def validate(self) -> None:
pass

View File

@ -0,0 +1,64 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from datetime import datetime, timedelta
from functools import partial
from flask import current_app
from sqlalchemy.exc import SQLAlchemyError
from superset.commands.distributed_lock.base import BaseDistributedLockCommand
from superset.daos.key_value import KeyValueDAO
from superset.exceptions import CreateKeyValueDistributedLockFailedException
from superset.key_value.exceptions import (
KeyValueCodecEncodeException,
KeyValueUpsertFailedError,
)
from superset.key_value.types import KeyValueResource
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__)
stats_logger = current_app.config["STATS_LOGGER"]
class CreateDistributedLock(BaseDistributedLockCommand):
lock_expiration = timedelta(seconds=30)
def validate(self) -> None:
pass
@transaction(
on_error=partial(
on_error,
catches=(
KeyValueCodecEncodeException,
KeyValueUpsertFailedError,
SQLAlchemyError,
),
reraise=CreateKeyValueDistributedLockFailedException,
),
)
def run(self) -> None:
KeyValueDAO.delete_expired_entries(self.resource)
KeyValueDAO.create_entry(
resource=KeyValueResource.LOCK,
value={"value": True},
codec=self.codec,
key=self.key,
expires_on=datetime.now() + self.lock_expiration,
)

View File

@ -14,49 +14,36 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from datetime import datetime
from functools import partial
from sqlalchemy import and_
from flask import current_app
from sqlalchemy.exc import SQLAlchemyError
from superset import db
from superset.commands.base import BaseCommand
from superset.commands.distributed_lock.base import BaseDistributedLockCommand
from superset.daos.key_value import KeyValueDAO
from superset.exceptions import DeleteKeyValueDistributedLockFailedException
from superset.key_value.exceptions import KeyValueDeleteFailedError
from superset.key_value.models import KeyValueEntry
from superset.key_value.types import KeyValueResource
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__)
stats_logger = current_app.config["STATS_LOGGER"]
class DeleteExpiredKeyValueCommand(BaseCommand):
resource: KeyValueResource
def __init__(self, resource: KeyValueResource):
"""
Delete all expired key-value pairs
:param resource: the resource (dashboard, chart etc)
:return: was the entry deleted or not
"""
self.resource = resource
@transaction(on_error=partial(on_error, reraise=KeyValueDeleteFailedError))
def run(self) -> None:
self.delete_expired()
class DeleteDistributedLock(BaseDistributedLockCommand):
def validate(self) -> None:
pass
def delete_expired(self) -> None:
(
db.session.query(KeyValueEntry)
.filter(
and_(
KeyValueEntry.resource == self.resource.value,
KeyValueEntry.expires_on <= datetime.now(),
)
)
.delete()
)
@transaction(
on_error=partial(
on_error,
catches=(
KeyValueDeleteFailedError,
SQLAlchemyError,
),
reraise=DeleteKeyValueDistributedLockFailedException,
),
)
def run(self) -> None:
KeyValueDAO.delete_entry(self.resource, self.key)

View File

@ -0,0 +1,45 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import logging
from typing import cast
from flask import current_app
from superset.commands.distributed_lock.base import BaseDistributedLockCommand
from superset.daos.key_value import KeyValueDAO
from superset.distributed_lock.types import LockValue
logger = logging.getLogger(__name__)
stats_logger = current_app.config["STATS_LOGGER"]
class GetDistributedLock(BaseDistributedLockCommand):
def validate(self) -> None:
pass
def run(self) -> LockValue | None:
entry = KeyValueDAO.get_entry(
resource=self.resource,
key=self.key,
)
if not entry or entry.is_expired():
return None
return cast(LockValue, self.codec.decode(entry.value))

View File

@ -20,8 +20,9 @@ from typing import Any, Optional
from sqlalchemy.exc import SQLAlchemyError
from superset import db
from superset.commands.explore.permalink.base import BaseExplorePermalinkCommand
from superset.commands.key_value.create import CreateKeyValueCommand
from superset.daos.key_value import KeyValueDAO
from superset.explore.permalink.exceptions import ExplorePermalinkCreateFailedError
from superset.explore.utils import check_access as check_chart_access
from superset.key_value.exceptions import (
@ -65,15 +66,12 @@ class CreateExplorePermalinkCommand(BaseExplorePermalinkCommand):
"datasource": self.datasource,
"state": self.state,
}
command = CreateKeyValueCommand(
resource=self.resource,
value=value,
codec=self.codec,
)
key = command.run()
if key.id is None:
entry = KeyValueDAO.create_entry(self.resource, value, self.codec)
db.session.flush()
key = entry.id
if key is None:
raise ExplorePermalinkCreateFailedError("Unexpected missing key id")
return encode_permalink_key(key=key.id, salt=self.salt)
return encode_permalink_key(key=key, salt=self.salt)
def validate(self) -> None:
pass

View File

@ -21,7 +21,7 @@ from sqlalchemy.exc import SQLAlchemyError
from superset.commands.dataset.exceptions import DatasetNotFoundError
from superset.commands.explore.permalink.base import BaseExplorePermalinkCommand
from superset.commands.key_value.get import GetKeyValueCommand
from superset.daos.key_value import KeyValueDAO
from superset.explore.permalink.exceptions import ExplorePermalinkGetFailedError
from superset.explore.permalink.types import ExplorePermalinkValue
from superset.explore.utils import check_access as check_chart_access
@ -44,11 +44,7 @@ class GetExplorePermalinkCommand(BaseExplorePermalinkCommand):
self.validate()
try:
key = decode_permalink_id(self.key, salt=self.salt)
value: Optional[ExplorePermalinkValue] = GetKeyValueCommand(
resource=self.resource,
key=key,
codec=self.codec,
).run()
value = KeyValueDAO.get_value(self.resource, key, self.codec)
if value:
chart_id: Optional[int] = value.get("chartId")
# keep this backward compatible for old permalinks

View File

@ -1,102 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from datetime import datetime
from functools import partial
from typing import Any, Optional, Union
from uuid import UUID
from superset import db
from superset.commands.base import BaseCommand
from superset.key_value.exceptions import KeyValueCreateFailedError
from superset.key_value.models import KeyValueEntry
from superset.key_value.types import Key, KeyValueCodec, KeyValueResource
from superset.utils.core import get_user_id
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__)
class CreateKeyValueCommand(BaseCommand):
resource: KeyValueResource
value: Any
codec: KeyValueCodec
key: Optional[Union[int, UUID]]
expires_on: Optional[datetime]
def __init__( # pylint: disable=too-many-arguments
self,
resource: KeyValueResource,
value: Any,
codec: KeyValueCodec,
key: Optional[Union[int, UUID]] = None,
expires_on: Optional[datetime] = None,
):
"""
Create a new key-value pair
:param resource: the resource (dashboard, chart etc)
:param value: the value to persist in the key-value store
:param codec: codec used to encode the value
:param key: id of entry (autogenerated if undefined)
:param expires_on: entry expiration time
:
"""
self.resource = resource
self.value = value
self.codec = codec
self.key = key
self.expires_on = expires_on
@transaction(on_error=partial(on_error, reraise=KeyValueCreateFailedError))
def run(self) -> Key:
"""
Persist the value
:return: the key associated with the persisted value
"""
return self.create()
def validate(self) -> None:
pass
def create(self) -> Key:
try:
value = self.codec.encode(self.value)
except Exception as ex:
raise KeyValueCreateFailedError("Unable to encode value") from ex
entry = KeyValueEntry(
resource=self.resource.value,
value=value,
created_on=datetime.now(),
created_by_fk=get_user_id(),
expires_on=self.expires_on,
)
if self.key is not None:
try:
if isinstance(self.key, UUID):
entry.uuid = self.key
else:
entry.id = self.key
except ValueError as ex:
raise KeyValueCreateFailedError() from ex
db.session.add(entry)
db.session.flush()
return Key(id=entry.id, uuid=entry.uuid)

View File

@ -1,63 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from functools import partial
from typing import Union
from uuid import UUID
from superset import db
from superset.commands.base import BaseCommand
from superset.key_value.exceptions import KeyValueDeleteFailedError
from superset.key_value.models import KeyValueEntry
from superset.key_value.types import KeyValueResource
from superset.key_value.utils import get_filter
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__)
class DeleteKeyValueCommand(BaseCommand):
key: Union[int, UUID]
resource: KeyValueResource
def __init__(self, resource: KeyValueResource, key: Union[int, UUID]):
"""
Delete a key-value pair
:param resource: the resource (dashboard, chart etc)
:param key: the key to delete
:return: was the entry deleted or not
"""
self.resource = resource
self.key = key
@transaction(on_error=partial(on_error, reraise=KeyValueDeleteFailedError))
def run(self) -> bool:
return self.delete()
def validate(self) -> None:
pass
def delete(self) -> bool:
if (
entry := db.session.query(KeyValueEntry)
.filter_by(**get_filter(self.resource, self.key))
.first()
):
db.session.delete(entry)
return True
return False

View File

@ -1,71 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any, Optional, Union
from uuid import UUID
from sqlalchemy.exc import SQLAlchemyError
from superset import db
from superset.commands.base import BaseCommand
from superset.key_value.exceptions import KeyValueGetFailedError
from superset.key_value.models import KeyValueEntry
from superset.key_value.types import KeyValueCodec, KeyValueResource
from superset.key_value.utils import get_filter
logger = logging.getLogger(__name__)
class GetKeyValueCommand(BaseCommand):
resource: KeyValueResource
key: Union[int, UUID]
codec: KeyValueCodec
def __init__(
self,
resource: KeyValueResource,
key: Union[int, UUID],
codec: KeyValueCodec,
):
"""
Retrieve a key value entry
:param resource: the resource (dashboard, chart etc)
:param key: the key to retrieve
:param codec: codec used to decode the value
:return: the value associated with the key if present
"""
self.resource = resource
self.key = key
self.codec = codec
def run(self) -> Any:
try:
return self.get()
except SQLAlchemyError as ex:
raise KeyValueGetFailedError() from ex
def validate(self) -> None:
pass
def get(self) -> Optional[Any]:
filter_ = get_filter(self.resource, self.key)
entry = db.session.query(KeyValueEntry).filter_by(**filter_).first()
if entry and not entry.is_expired():
return self.codec.decode(entry.value)
return None

View File

@ -1,87 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from datetime import datetime
from functools import partial
from typing import Any, Optional, Union
from uuid import UUID
from superset import db
from superset.commands.base import BaseCommand
from superset.key_value.exceptions import KeyValueUpdateFailedError
from superset.key_value.models import KeyValueEntry
from superset.key_value.types import Key, KeyValueCodec, KeyValueResource
from superset.key_value.utils import get_filter
from superset.utils.core import get_user_id
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__)
class UpdateKeyValueCommand(BaseCommand):
resource: KeyValueResource
value: Any
codec: KeyValueCodec
key: Union[int, UUID]
expires_on: Optional[datetime]
def __init__( # pylint: disable=too-many-arguments
self,
resource: KeyValueResource,
key: Union[int, UUID],
value: Any,
codec: KeyValueCodec,
expires_on: Optional[datetime] = None,
):
"""
Update a key value entry
:param resource: the resource (dashboard, chart etc)
:param key: the key to update
:param value: the value to persist in the key-value store
:param codec: codec used to encode the value
:param expires_on: entry expiration time
:return: the key associated with the updated value
"""
self.resource = resource
self.key = key
self.value = value
self.codec = codec
self.expires_on = expires_on
@transaction(on_error=partial(on_error, reraise=KeyValueUpdateFailedError))
def run(self) -> Optional[Key]:
return self.update()
def validate(self) -> None:
pass
def update(self) -> Optional[Key]:
filter_ = get_filter(self.resource, self.key)
entry: KeyValueEntry = (
db.session.query(KeyValueEntry).filter_by(**filter_).first()
)
if entry:
entry.value = self.codec.encode(self.value)
entry.expires_on = self.expires_on
entry.changed_on = datetime.now()
entry.changed_by_fk = get_user_id()
db.session.flush()
return Key(id=entry.id, uuid=entry.uuid)
return None

View File

@ -1,104 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from datetime import datetime
from functools import partial
from typing import Any, Optional, Union
from uuid import UUID
from sqlalchemy.exc import SQLAlchemyError
from superset import db
from superset.commands.base import BaseCommand
from superset.commands.key_value.create import CreateKeyValueCommand
from superset.key_value.exceptions import (
KeyValueCreateFailedError,
KeyValueUpsertFailedError,
)
from superset.key_value.models import KeyValueEntry
from superset.key_value.types import Key, KeyValueCodec, KeyValueResource
from superset.key_value.utils import get_filter
from superset.utils.core import get_user_id
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__)
class UpsertKeyValueCommand(BaseCommand):
resource: KeyValueResource
value: Any
key: Union[int, UUID]
codec: KeyValueCodec
expires_on: Optional[datetime]
def __init__( # pylint: disable=too-many-arguments
self,
resource: KeyValueResource,
key: Union[int, UUID],
value: Any,
codec: KeyValueCodec,
expires_on: Optional[datetime] = None,
):
"""
Upsert a key value entry
:param resource: the resource (dashboard, chart etc)
:param key: the key to update
:param value: the value to persist in the key-value store
:param codec: codec used to encode the value
:param expires_on: entry expiration time
:return: the key associated with the updated value
"""
self.resource = resource
self.key = key
self.value = value
self.codec = codec
self.expires_on = expires_on
@transaction(
on_error=partial(
on_error,
catches=(KeyValueCreateFailedError, SQLAlchemyError),
reraise=KeyValueUpsertFailedError,
),
)
def run(self) -> Key:
return self.upsert()
def validate(self) -> None:
pass
def upsert(self) -> Key:
if (
entry := db.session.query(KeyValueEntry)
.filter_by(**get_filter(self.resource, self.key))
.first()
):
entry.value = self.codec.encode(self.value)
entry.expires_on = self.expires_on
entry.changed_on = datetime.now()
entry.changed_by_fk = get_user_id()
return Key(entry.id, entry.uuid)
return CreateKeyValueCommand(
resource=self.resource,
value=self.value,
codec=self.codec,
key=self.key,
expires_on=self.expires_on,
).run()

145
superset/daos/key_value.py Normal file
View File

@ -0,0 +1,145 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import logging
from datetime import datetime
from typing import Any
from uuid import UUID
from sqlalchemy import and_
from superset import db
from superset.daos.base import BaseDAO
from superset.key_value.exceptions import (
KeyValueCreateFailedError,
KeyValueUpdateFailedError,
)
from superset.key_value.models import KeyValueEntry
from superset.key_value.types import Key, KeyValueCodec, KeyValueResource
from superset.key_value.utils import get_filter
from superset.utils.core import get_user_id
logger = logging.getLogger(__name__)
class KeyValueDAO(BaseDAO[KeyValueEntry]):
@staticmethod
def get_entry(
resource: KeyValueResource,
key: Key,
) -> KeyValueEntry | None:
filter_ = get_filter(resource, key)
return db.session.query(KeyValueEntry).filter_by(**filter_).first()
@classmethod
def get_value(
cls,
resource: KeyValueResource,
key: Key,
codec: KeyValueCodec,
) -> Any:
entry = cls.get_entry(resource, key)
if not entry or entry.is_expired():
return None
return codec.decode(entry.value)
@staticmethod
def delete_entry(resource: KeyValueResource, key: Key) -> bool:
if entry := KeyValueDAO.get_entry(resource, key):
db.session.delete(entry)
return True
return False
@staticmethod
def delete_expired_entries(resource: KeyValueResource) -> None:
(
db.session.query(KeyValueEntry)
.filter(
and_(
KeyValueEntry.resource == resource.value,
KeyValueEntry.expires_on <= datetime.now(),
)
)
.delete()
)
@staticmethod
def create_entry(
resource: KeyValueResource,
value: Any,
codec: KeyValueCodec,
key: Key | None = None,
expires_on: datetime | None = None,
) -> KeyValueEntry:
try:
encoded_value = codec.encode(value)
except Exception as ex:
raise KeyValueCreateFailedError("Unable to encode value") from ex
entry = KeyValueEntry(
resource=resource.value,
value=encoded_value,
created_on=datetime.now(),
created_by_fk=get_user_id(),
expires_on=expires_on,
)
if key is not None:
try:
if isinstance(key, UUID):
entry.uuid = key
else:
entry.id = key
except ValueError as ex:
raise KeyValueCreateFailedError() from ex
db.session.add(entry)
return entry
@staticmethod
def upsert_entry(
resource: KeyValueResource,
value: Any,
codec: KeyValueCodec,
key: Key,
expires_on: datetime | None = None,
) -> KeyValueEntry:
if entry := KeyValueDAO.get_entry(resource, key):
entry.value = codec.encode(value)
entry.expires_on = expires_on
entry.changed_on = datetime.now()
entry.changed_by_fk = get_user_id()
return entry
return KeyValueDAO.create_entry(resource, value, codec, key, expires_on)
@staticmethod
def update_entry(
resource: KeyValueResource,
value: Any,
codec: KeyValueCodec,
key: Key,
expires_on: datetime | None = None,
) -> KeyValueEntry:
if entry := KeyValueDAO.get_entry(resource, key):
entry.value = codec.encode(value)
entry.expires_on = expires_on
entry.changed_on = datetime.now()
entry.changed_by_fk = get_user_id()
return entry
raise KeyValueUpdateFailedError()

View File

@ -21,40 +21,18 @@ import logging
import uuid
from collections.abc import Iterator
from contextlib import contextmanager
from datetime import datetime, timedelta
from typing import Any, cast, TypeVar, Union
from datetime import timedelta
from typing import Any
from superset.distributed_lock.utils import get_key
from superset.exceptions import CreateKeyValueDistributedLockFailedException
from superset.key_value.exceptions import KeyValueCreateFailedError
from superset.key_value.types import JsonKeyValueCodec, KeyValueResource
from superset.utils import json
LOCK_EXPIRATION = timedelta(seconds=30)
logger = logging.getLogger(__name__)
def serialize(params: dict[str, Any]) -> str:
"""
Serialize parameters into a string.
"""
T = TypeVar(
"T",
bound=Union[dict[str, Any], list[Any], int, float, str, bool, None],
)
def sort(obj: T) -> T:
if isinstance(obj, dict):
return cast(T, {k: sort(v) for k, v in sorted(obj.items())})
if isinstance(obj, list):
return cast(T, [sort(x) for x in obj])
return obj
return json.dumps(params)
def get_key(namespace: str, **kwargs: Any) -> uuid.UUID:
return uuid.uuid5(uuid.uuid5(uuid.NAMESPACE_DNS, namespace), serialize(kwargs))
CODEC = JsonKeyValueCodec()
LOCK_EXPIRATION = timedelta(seconds=30)
RESOURCE = KeyValueResource.LOCK
@contextmanager
@ -75,28 +53,25 @@ def KeyValueDistributedLock( # pylint: disable=invalid-name
:yields: A unique identifier (UUID) for the acquired lock (the KV key).
:raises CreateKeyValueDistributedLockFailedException: If the lock is taken.
"""
# pylint: disable=import-outside-toplevel
from superset.commands.key_value.create import CreateKeyValueCommand
from superset.commands.key_value.delete import DeleteKeyValueCommand
from superset.commands.key_value.delete_expired import DeleteExpiredKeyValueCommand
from superset.commands.distributed_lock.create import CreateDistributedLock
from superset.commands.distributed_lock.delete import DeleteDistributedLock
from superset.commands.distributed_lock.get import GetDistributedLock
key = get_key(namespace, **kwargs)
value = GetDistributedLock(namespace=namespace, params=kwargs).run()
if value:
logger.debug("Lock on namespace %s for key %s already taken", namespace, key)
raise CreateKeyValueDistributedLockFailedException("Lock already taken")
logger.debug("Acquiring lock on namespace %s for key %s", namespace, key)
try:
DeleteExpiredKeyValueCommand(resource=KeyValueResource.LOCK).run()
CreateKeyValueCommand(
resource=KeyValueResource.LOCK,
codec=JsonKeyValueCodec(),
key=key,
value=True,
expires_on=datetime.now() + LOCK_EXPIRATION,
).run()
CreateDistributedLock(namespace=namespace, params=kwargs).run()
except CreateKeyValueDistributedLockFailedException as ex:
logger.debug("Lock on namespace %s for key %s already taken", namespace, key)
raise CreateKeyValueDistributedLockFailedException("Lock already taken") from ex
yield key
DeleteKeyValueCommand(resource=KeyValueResource.LOCK, key=key).run()
logger.debug("Removed lock on namespace %s for key %s", namespace, key)
except KeyValueCreateFailedError as ex:
raise CreateKeyValueDistributedLockFailedException(
"Error acquiring lock"
) from ex
yield key
DeleteDistributedLock(namespace=namespace, params=kwargs).run()
logger.debug("Removed lock on namespace %s for key %s", namespace, key)

View File

@ -14,3 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import TypedDict
class LockValue(TypedDict):
value: bool

View File

@ -14,3 +14,32 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import uuid
from typing import Any, cast, TypeVar, Union
from superset.utils import json
def serialize(params: dict[str, Any]) -> str:
"""
Serialize parameters into a string.
"""
T = TypeVar(
"T",
bound=Union[dict[str, Any], list[Any], int, float, str, bool, None],
)
def sort(obj: T) -> T:
if isinstance(obj, dict):
return cast(T, {k: sort(v) for k, v in sorted(obj.items())})
if isinstance(obj, list):
return cast(T, [sort(x) for x in obj])
return obj
return json.dumps(params)
def get_key(namespace: str, **kwargs: Any) -> uuid.UUID:
return uuid.uuid5(uuid.uuid5(uuid.NAMESPACE_DNS, namespace), serialize(kwargs))

View File

@ -379,6 +379,12 @@ class CreateKeyValueDistributedLockFailedException(Exception):
"""
class DeleteKeyValueDistributedLockFailedException(Exception):
"""
Exception to signalize failure to delete lock.
"""
class DatabaseNotFoundException(SupersetErrorException):
status = 404

View File

@ -21,7 +21,10 @@ from uuid import UUID, uuid3
from flask import current_app, Flask, has_app_context
from flask_caching import BaseCache
from sqlalchemy.exc import SQLAlchemyError
from superset import db
from superset.daos.key_value import KeyValueDAO
from superset.key_value.exceptions import KeyValueCreateFailedError
from superset.key_value.types import (
KeyValueCodec,
@ -29,6 +32,7 @@ from superset.key_value.types import (
PickleKeyValueCodec,
)
from superset.key_value.utils import get_uuid_namespace
from superset.utils.decorators import transaction
RESOURCE = KeyValueResource.METASTORE_CACHE
@ -68,15 +72,6 @@ class SupersetMetastoreCache(BaseCache):
def get_key(self, key: str) -> UUID:
return uuid3(self.namespace, key)
@staticmethod
def _prune() -> None:
# pylint: disable=import-outside-toplevel
from superset.commands.key_value.delete_expired import (
DeleteExpiredKeyValueCommand,
)
DeleteExpiredKeyValueCommand(resource=RESOURCE).run()
def _get_expiry(self, timeout: Optional[int]) -> Optional[datetime]:
timeout = self._normalize_timeout(timeout)
if timeout is not None and timeout > 0:
@ -84,44 +79,34 @@ class SupersetMetastoreCache(BaseCache):
return None
def set(self, key: str, value: Any, timeout: Optional[int] = None) -> bool:
# pylint: disable=import-outside-toplevel
from superset.commands.key_value.upsert import UpsertKeyValueCommand
UpsertKeyValueCommand(
KeyValueDAO.upsert_entry(
resource=RESOURCE,
key=self.get_key(key),
value=value,
codec=self.codec,
expires_on=self._get_expiry(timeout),
).run()
)
db.session.commit() # pylint: disable=consider-using-transaction
return True
def add(self, key: str, value: Any, timeout: Optional[int] = None) -> bool:
# pylint: disable=import-outside-toplevel
from superset.commands.key_value.create import CreateKeyValueCommand
try:
self._prune()
CreateKeyValueCommand(
KeyValueDAO.delete_expired_entries(RESOURCE)
KeyValueDAO.create_entry(
resource=RESOURCE,
value=value,
codec=self.codec,
key=self.get_key(key),
expires_on=self._get_expiry(timeout),
).run()
)
db.session.commit() # pylint: disable=consider-using-transaction
return True
except KeyValueCreateFailedError:
except (SQLAlchemyError, KeyValueCreateFailedError):
db.session.rollback() # pylint: disable=consider-using-transaction
return False
def get(self, key: str) -> Any:
# pylint: disable=import-outside-toplevel
from superset.commands.key_value.get import GetKeyValueCommand
return GetKeyValueCommand(
resource=RESOURCE,
key=self.get_key(key),
codec=self.codec,
).run()
return KeyValueDAO.get_value(RESOURCE, self.get_key(key), self.codec)
def has(self, key: str) -> bool:
entry = self.get(key)
@ -129,8 +114,6 @@ class SupersetMetastoreCache(BaseCache):
return True
return False
@transaction()
def delete(self, key: str) -> Any:
# pylint: disable=import-outside-toplevel
from superset.commands.key_value.delete import DeleteKeyValueCommand
return DeleteKeyValueCommand(resource=RESOURCE, key=self.get_key(key)).run()
return KeyValueDAO.delete_entry(RESOURCE, self.get_key(key))

View File

@ -18,8 +18,10 @@
from typing import Any, Optional
from uuid import uuid3
from superset.daos.key_value import KeyValueDAO
from superset.key_value.types import JsonKeyValueCodec, KeyValueResource, SharedKey
from superset.key_value.utils import get_uuid_namespace, random_key
from superset.utils.decorators import transaction
RESOURCE = KeyValueResource.APP
NAMESPACE = get_uuid_namespace("")
@ -27,24 +29,14 @@ CODEC = JsonKeyValueCodec()
def get_shared_value(key: SharedKey) -> Optional[Any]:
# pylint: disable=import-outside-toplevel
from superset.commands.key_value.get import GetKeyValueCommand
uuid_key = uuid3(NAMESPACE, key)
return GetKeyValueCommand(RESOURCE, key=uuid_key, codec=CODEC).run()
return KeyValueDAO.get_value(RESOURCE, uuid_key, CODEC)
@transaction()
def set_shared_value(key: SharedKey, value: Any) -> None:
# pylint: disable=import-outside-toplevel
from superset.commands.key_value.create import CreateKeyValueCommand
uuid_key = uuid3(NAMESPACE, key)
CreateKeyValueCommand(
resource=RESOURCE,
value=value,
key=uuid_key,
codec=CODEC,
).run()
KeyValueDAO.create_entry(RESOURCE, value, CODEC, uuid_key)
def get_permalink_salt(key: SharedKey) -> str:

View File

@ -19,8 +19,7 @@ from __future__ import annotations
import json
import pickle
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, TypedDict
from typing import Any, TypedDict, Union
from uuid import UUID
from marshmallow import Schema, ValidationError
@ -31,11 +30,7 @@ from superset.key_value.exceptions import (
)
from superset.utils.backports import StrEnum
@dataclass
class Key:
id: int | None
uuid: UUID | None
Key = Union[int, UUID]
class KeyValueFilter(TypedDict, total=False):

View File

@ -25,7 +25,7 @@ import hashids
from flask_babel import gettext as _
from superset.key_value.exceptions import KeyValueParseKeyError
from superset.key_value.types import KeyValueFilter, KeyValueResource
from superset.key_value.types import Key, KeyValueFilter, KeyValueResource
from superset.utils.json import json_dumps_w_dates
HASHIDS_MIN_LENGTH = 11
@ -35,7 +35,7 @@ def random_key() -> str:
return token_urlsafe(48)
def get_filter(resource: KeyValueResource, key: int | UUID) -> KeyValueFilter:
def get_filter(resource: KeyValueResource, key: Key) -> KeyValueFilter:
try:
filter_: KeyValueFilter = {"resource": resource.value}
if isinstance(key, UUID):

View File

@ -26,9 +26,9 @@ from flask import current_app, url_for
from marshmallow import EXCLUDE, fields, post_load, Schema
from superset import db
from superset.distributed_lock import KeyValueDistributedLock
from superset.exceptions import CreateKeyValueDistributedLockFailedException
from superset.superset_typing import OAuth2ClientConfig, OAuth2State
from superset.utils.lock import KeyValueDistributedLock
if TYPE_CHECKING:
from superset.db_engine_specs.base import BaseEngineSpec

View File

@ -133,11 +133,11 @@ class TestCreatePermalinkDataCommand(SupersetTestCase):
assert cache_data.get("datasource") == datasource
@patch("superset.security.manager.g")
@patch("superset.commands.key_value.get.GetKeyValueCommand.run")
@patch("superset.daos.key_value.KeyValueDAO.get_value")
@patch("superset.commands.explore.permalink.get.decode_permalink_id")
@pytest.mark.usefixtures("create_dataset", "create_slice")
def test_get_permalink_command_with_old_dataset_key(
self, decode_id_mock, get_kv_command_mock, mock_g
self, decode_id_mock, kv_get_value_mock, mock_g
):
mock_g.user = security_manager.find_user("admin")
app.config["EXPLORE_FORM_DATA_CACHE_CONFIG"] = {
@ -149,13 +149,14 @@ class TestCreatePermalinkDataCommand(SupersetTestCase):
)
slice = db.session.query(Slice).filter_by(slice_name="slice_name").first()
datasource_string = f"{dataset.id}__{DatasourceType.TABLE}"
datasource_string = f"{dataset.id}__{DatasourceType.TABLE.value}"
decode_id_mock.return_value = "123456"
get_kv_command_mock.return_value = {
kv_get_value_mock.return_value = {
"chartId": slice.id,
"datasetId": dataset.id,
"datasource": datasource_string,
"datasourceType": DatasourceType.TABLE.value,
"state": {
"formData": {"datasource": datasource_string, "slice_id": slice.id}
},

View File

@ -60,6 +60,7 @@ def test_caching_flow(app_context: AppContext, cache: SupersetMetastoreCache) ->
assert cache.has(FIRST_KEY) is False
assert cache.add(FIRST_KEY, FIRST_KEY_INITIAL_VALUE) is True
assert cache.has(FIRST_KEY) is True
assert cache.get(FIRST_KEY) == FIRST_KEY_INITIAL_VALUE
cache.set(SECOND_KEY, SECOND_VALUE)
assert cache.get(FIRST_KEY) == FIRST_KEY_INITIAL_VALUE
assert cache.get(SECOND_KEY) == SECOND_VALUE

View File

@ -1,96 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import pickle
import pytest
from flask.ctx import AppContext
from flask_appbuilder.security.sqla.models import User
from superset.extensions import db
from superset.key_value.exceptions import KeyValueCreateFailedError
from superset.utils import json
from superset.utils.core import override_user
from tests.integration_tests.key_value.commands.fixtures import (
admin, # noqa: F401
JSON_CODEC,
JSON_VALUE,
PICKLE_CODEC,
PICKLE_VALUE,
RESOURCE,
)
def test_create_id_entry(app_context: AppContext, admin: User) -> None: # noqa: F811
from superset.commands.key_value.create import CreateKeyValueCommand
from superset.key_value.models import KeyValueEntry
with override_user(admin):
key = CreateKeyValueCommand(
resource=RESOURCE,
value=JSON_VALUE,
codec=JSON_CODEC,
).run()
entry = db.session.query(KeyValueEntry).filter_by(id=key.id).one()
assert json.loads(entry.value) == JSON_VALUE
assert entry.created_by_fk == admin.id
db.session.delete(entry)
db.session.commit()
def test_create_uuid_entry(app_context: AppContext, admin: User) -> None: # noqa: F811
from superset.commands.key_value.create import CreateKeyValueCommand
from superset.key_value.models import KeyValueEntry
with override_user(admin):
key = CreateKeyValueCommand(
resource=RESOURCE, value=JSON_VALUE, codec=JSON_CODEC
).run()
entry = db.session.query(KeyValueEntry).filter_by(uuid=key.uuid).one()
assert json.loads(entry.value) == JSON_VALUE
assert entry.created_by_fk == admin.id
db.session.delete(entry)
db.session.commit()
def test_create_fail_json_entry(app_context: AppContext, admin: User) -> None: # noqa: F811
from superset.commands.key_value.create import CreateKeyValueCommand
with pytest.raises(KeyValueCreateFailedError):
CreateKeyValueCommand(
resource=RESOURCE,
value=PICKLE_VALUE,
codec=JSON_CODEC,
).run()
def test_create_pickle_entry(app_context: AppContext, admin: User) -> None: # noqa: F811
from superset.commands.key_value.create import CreateKeyValueCommand
from superset.key_value.models import KeyValueEntry
with override_user(admin):
key = CreateKeyValueCommand(
resource=RESOURCE,
value=PICKLE_VALUE,
codec=PICKLE_CODEC,
).run()
entry = db.session.query(KeyValueEntry).filter_by(id=key.id).one()
assert type(pickle.loads(entry.value)) == type(PICKLE_VALUE)
assert entry.created_by_fk == admin.id
db.session.delete(entry)
db.session.commit()

View File

@ -1,84 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from typing import TYPE_CHECKING
from uuid import UUID
import pytest
from flask.ctx import AppContext
from flask_appbuilder.security.sqla.models import User
from superset.extensions import db
from superset.utils import json
from tests.integration_tests.key_value.commands.fixtures import (
admin, # noqa: F401
JSON_VALUE,
RESOURCE,
)
if TYPE_CHECKING:
from superset.key_value.models import KeyValueEntry
ID_KEY = 234
UUID_KEY = UUID("5aae143c-44f1-478e-9153-ae6154df333a")
@pytest.fixture
def key_value_entry() -> KeyValueEntry:
from superset.key_value.models import KeyValueEntry
entry = KeyValueEntry(
id=ID_KEY,
uuid=UUID_KEY,
resource=RESOURCE,
value=bytes(json.dumps(JSON_VALUE), encoding="utf-8"),
)
db.session.add(entry)
db.session.flush()
return entry
def test_delete_id_entry(
app_context: AppContext,
admin: User, # noqa: F811
key_value_entry: KeyValueEntry,
) -> None:
from superset.commands.key_value.delete import DeleteKeyValueCommand
assert DeleteKeyValueCommand(resource=RESOURCE, key=ID_KEY).run() is True
db.session.commit()
def test_delete_uuid_entry(
app_context: AppContext,
admin: User, # noqa: F811
key_value_entry: KeyValueEntry,
) -> None:
from superset.commands.key_value.delete import DeleteKeyValueCommand
assert DeleteKeyValueCommand(resource=RESOURCE, key=UUID_KEY).run() is True
db.session.commit()
def test_delete_entry_missing(
app_context: AppContext,
admin: User, # noqa: F811
) -> None:
from superset.commands.key_value.delete import DeleteKeyValueCommand
assert DeleteKeyValueCommand(resource=RESOURCE, key=456).run() is False

View File

@ -1,69 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from collections.abc import Generator
from typing import TYPE_CHECKING
from uuid import UUID
import pytest
from flask_appbuilder.security.sqla.models import User
from superset.extensions import db
from superset.key_value.types import (
JsonKeyValueCodec,
KeyValueResource,
PickleKeyValueCodec,
)
from superset.utils import json
from tests.integration_tests.test_app import app
if TYPE_CHECKING:
from superset.key_value.models import KeyValueEntry
ID_KEY = 123
UUID_KEY = UUID("3e7a2ab8-bcaf-49b0-a5df-dfb432f291cc")
RESOURCE = KeyValueResource.APP
JSON_VALUE = {"foo": "bar"}
PICKLE_VALUE = object()
JSON_CODEC = JsonKeyValueCodec()
PICKLE_CODEC = PickleKeyValueCodec()
@pytest.fixture
def key_value_entry() -> Generator[KeyValueEntry, None, None]:
from superset.key_value.models import KeyValueEntry
entry = KeyValueEntry(
id=ID_KEY,
uuid=UUID_KEY,
resource=RESOURCE,
value=bytes(json.dumps(JSON_VALUE), encoding="utf-8"),
)
db.session.add(entry)
db.session.flush()
yield entry
db.session.delete(entry)
db.session.commit()
@pytest.fixture
def admin() -> User:
with app.app_context(): # noqa: F841
admin = db.session.query(User).filter_by(username="admin").one()
return admin

View File

@ -1,103 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import uuid
from datetime import datetime, timedelta
from typing import TYPE_CHECKING
from flask.ctx import AppContext
from superset.extensions import db
from superset.utils import json
from tests.integration_tests.key_value.commands.fixtures import (
ID_KEY,
JSON_CODEC,
JSON_VALUE,
key_value_entry, # noqa: F401
RESOURCE,
UUID_KEY,
)
if TYPE_CHECKING:
from superset.key_value.models import KeyValueEntry
def test_get_id_entry(app_context: AppContext, key_value_entry: KeyValueEntry) -> None: # noqa: F811
from superset.commands.key_value.get import GetKeyValueCommand
value = GetKeyValueCommand(resource=RESOURCE, key=ID_KEY, codec=JSON_CODEC).run()
assert value == JSON_VALUE
def test_get_uuid_entry(
app_context: AppContext,
key_value_entry: KeyValueEntry, # noqa: F811
) -> None:
from superset.commands.key_value.get import GetKeyValueCommand
value = GetKeyValueCommand(resource=RESOURCE, key=UUID_KEY, codec=JSON_CODEC).run()
assert value == JSON_VALUE
def test_get_id_entry_missing(
app_context: AppContext,
key_value_entry: KeyValueEntry, # noqa: F811
) -> None:
from superset.commands.key_value.get import GetKeyValueCommand
value = GetKeyValueCommand(resource=RESOURCE, key=456, codec=JSON_CODEC).run()
assert value is None
def test_get_expired_entry(app_context: AppContext) -> None:
from superset.commands.key_value.get import GetKeyValueCommand
from superset.key_value.models import KeyValueEntry
entry = KeyValueEntry(
id=678,
uuid=uuid.uuid4(),
resource=RESOURCE,
value=bytes(json.dumps(JSON_VALUE), encoding="utf-8"),
expires_on=datetime.now() - timedelta(days=1),
)
db.session.add(entry)
db.session.flush()
value = GetKeyValueCommand(resource=RESOURCE, key=ID_KEY, codec=JSON_CODEC).run()
assert value is None
db.session.delete(entry)
db.session.commit()
def test_get_future_expiring_entry(app_context: AppContext) -> None:
from superset.commands.key_value.get import GetKeyValueCommand
from superset.key_value.models import KeyValueEntry
id_ = 789
entry = KeyValueEntry(
id=id_,
uuid=uuid.uuid4(),
resource=RESOURCE,
value=bytes(json.dumps(JSON_VALUE), encoding="utf-8"),
expires_on=datetime.now() + timedelta(days=1),
)
db.session.add(entry)
db.session.flush()
value = GetKeyValueCommand(resource=RESOURCE, key=id_, codec=JSON_CODEC).run()
assert value == JSON_VALUE
db.session.delete(entry)
db.session.commit()

View File

@ -1,97 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from typing import TYPE_CHECKING
from flask.ctx import AppContext
from flask_appbuilder.security.sqla.models import User
from superset.extensions import db
from superset.utils import json
from superset.utils.core import override_user
from tests.integration_tests.key_value.commands.fixtures import (
admin, # noqa: F401
ID_KEY,
JSON_CODEC,
key_value_entry, # noqa: F401
RESOURCE,
UUID_KEY,
)
if TYPE_CHECKING:
from superset.key_value.models import KeyValueEntry
NEW_VALUE = "new value"
def test_update_id_entry(
app_context: AppContext,
admin: User, # noqa: F811
key_value_entry: KeyValueEntry, # noqa: F811
) -> None:
from superset.commands.key_value.update import UpdateKeyValueCommand
from superset.key_value.models import KeyValueEntry
with override_user(admin):
key = UpdateKeyValueCommand(
resource=RESOURCE,
key=ID_KEY,
value=NEW_VALUE,
codec=JSON_CODEC,
).run()
assert key is not None
assert key.id == ID_KEY
entry = db.session.query(KeyValueEntry).filter_by(id=ID_KEY).one()
assert json.loads(entry.value) == NEW_VALUE
assert entry.changed_by_fk == admin.id
def test_update_uuid_entry(
app_context: AppContext,
admin: User, # noqa: F811
key_value_entry: KeyValueEntry, # noqa: F811
) -> None:
from superset.commands.key_value.update import UpdateKeyValueCommand
from superset.key_value.models import KeyValueEntry
with override_user(admin):
key = UpdateKeyValueCommand(
resource=RESOURCE,
key=UUID_KEY,
value=NEW_VALUE,
codec=JSON_CODEC,
).run()
assert key is not None
assert key.uuid == UUID_KEY
entry = db.session.query(KeyValueEntry).filter_by(uuid=UUID_KEY).one()
assert json.loads(entry.value) == NEW_VALUE
assert entry.changed_by_fk == admin.id
def test_update_missing_entry(app_context: AppContext, admin: User) -> None: # noqa: F811
from superset.commands.key_value.update import UpdateKeyValueCommand
with override_user(admin):
key = UpdateKeyValueCommand(
resource=RESOURCE,
key=456,
value=NEW_VALUE,
codec=JSON_CODEC,
).run()
assert key is None

View File

@ -1,101 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from typing import TYPE_CHECKING
from flask.ctx import AppContext
from flask_appbuilder.security.sqla.models import User
from superset.extensions import db
from superset.utils import json
from superset.utils.core import override_user
from tests.integration_tests.key_value.commands.fixtures import (
admin, # noqa: F401
ID_KEY,
JSON_CODEC,
key_value_entry, # noqa: F401
RESOURCE,
UUID_KEY,
)
if TYPE_CHECKING:
from superset.key_value.models import KeyValueEntry
NEW_VALUE = "new value"
def test_upsert_id_entry(
app_context: AppContext,
admin: User, # noqa: F811
key_value_entry: KeyValueEntry, # noqa: F811
) -> None:
from superset.commands.key_value.upsert import UpsertKeyValueCommand
from superset.key_value.models import KeyValueEntry
with override_user(admin):
key = UpsertKeyValueCommand(
resource=RESOURCE,
key=ID_KEY,
value=NEW_VALUE,
codec=JSON_CODEC,
).run()
assert key is not None
assert key.id == ID_KEY
entry = db.session.query(KeyValueEntry).filter_by(id=int(ID_KEY)).one()
assert json.loads(entry.value) == NEW_VALUE
assert entry.changed_by_fk == admin.id
def test_upsert_uuid_entry(
app_context: AppContext,
admin: User, # noqa: F811
key_value_entry: KeyValueEntry, # noqa: F811
) -> None:
from superset.commands.key_value.upsert import UpsertKeyValueCommand
from superset.key_value.models import KeyValueEntry
with override_user(admin):
key = UpsertKeyValueCommand(
resource=RESOURCE,
key=UUID_KEY,
value=NEW_VALUE,
codec=JSON_CODEC,
).run()
assert key is not None
assert key.uuid == UUID_KEY
entry = db.session.query(KeyValueEntry).filter_by(uuid=UUID_KEY).one()
assert json.loads(entry.value) == NEW_VALUE
assert entry.changed_by_fk == admin.id
def test_upsert_missing_entry(app_context: AppContext, admin: User) -> None: # noqa: F811
from superset.commands.key_value.upsert import UpsertKeyValueCommand
from superset.key_value.models import KeyValueEntry
with override_user(admin):
key = UpsertKeyValueCommand(
resource=RESOURCE,
key=456,
value=NEW_VALUE,
codec=JSON_CODEC,
).run()
assert key is not None
assert key.id == 456
db.session.query(KeyValueEntry).filter_by(id=456).delete()
db.session.commit()

View File

@ -0,0 +1,395 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-argument, import-outside-toplevel, unused-import
from __future__ import annotations
import pickle
from datetime import datetime, timedelta
from typing import Generator, TYPE_CHECKING
from uuid import UUID
import pytest
from flask.ctx import AppContext
from flask_appbuilder.security.sqla.models import User
from superset.extensions import db
from superset.key_value.exceptions import (
KeyValueCreateFailedError,
KeyValueUpdateFailedError,
)
from superset.key_value.types import (
JsonKeyValueCodec,
KeyValueResource,
PickleKeyValueCodec,
)
from superset.utils import json
from superset.utils.core import override_user
from tests.unit_tests.fixtures.common import admin_user, after_each # noqa: F401
if TYPE_CHECKING:
from superset.key_value.models import KeyValueEntry
ID_KEY = 123
UUID_KEY = UUID("3e7a2ab8-bcaf-49b0-a5df-dfb432f291cc")
RESOURCE = KeyValueResource.APP
JSON_VALUE = {"foo": "bar"}
PICKLE_VALUE = object()
JSON_CODEC = JsonKeyValueCodec()
PICKLE_CODEC = PickleKeyValueCodec()
NEW_VALUE = {"foo": "baz"}
@pytest.fixture
def key_value_entry() -> Generator[KeyValueEntry, None, None]:
from superset.key_value.models import KeyValueEntry
entry = KeyValueEntry(
id=ID_KEY,
uuid=UUID_KEY,
resource=RESOURCE,
value=JSON_CODEC.encode(JSON_VALUE),
)
db.session.add(entry)
db.session.flush()
yield entry
def test_create_id_entry(
app_context: AppContext,
admin_user: User, # noqa: F811
after_each: None, # noqa: F811
) -> None:
from superset.daos.key_value import KeyValueDAO
from superset.key_value.models import KeyValueEntry
with override_user(admin_user):
created_entry = KeyValueDAO.create_entry(
resource=RESOURCE,
value=JSON_VALUE,
codec=JSON_CODEC,
)
db.session.flush()
found_entry = (
db.session.query(KeyValueEntry).filter_by(id=created_entry.id).one()
)
assert json.loads(found_entry.value) == JSON_VALUE
assert found_entry.created_by_fk == admin_user.id
def test_create_uuid_entry(
app_context: AppContext,
admin_user: User, # noqa: F811
after_each: None, # noqa: F811
) -> None:
from superset.daos.key_value import KeyValueDAO
from superset.key_value.models import KeyValueEntry
with override_user(admin_user):
created_entry = KeyValueDAO.create_entry(
resource=RESOURCE, value=JSON_VALUE, codec=JSON_CODEC
)
db.session.flush()
found_entry = (
db.session.query(KeyValueEntry).filter_by(uuid=created_entry.uuid).one()
)
assert json.loads(found_entry.value) == JSON_VALUE
assert found_entry.created_by_fk == admin_user.id
def test_create_fail_json_entry(
app_context: AppContext,
after_each: None, # noqa: F811
) -> None:
from superset.daos.key_value import KeyValueDAO
with pytest.raises(KeyValueCreateFailedError):
KeyValueDAO.create_entry(
resource=RESOURCE,
value=PICKLE_VALUE,
codec=JSON_CODEC,
)
def test_create_pickle_entry(
app_context: AppContext,
admin_user: User, # noqa: F811
after_each: None, # noqa: F811
) -> None:
from superset.daos.key_value import KeyValueDAO
from superset.key_value.models import KeyValueEntry
with override_user(admin_user):
created_entry = KeyValueDAO.create_entry(
resource=RESOURCE,
value=PICKLE_VALUE,
codec=PICKLE_CODEC,
)
db.session.flush()
found_entry = (
db.session.query(KeyValueEntry).filter_by(id=created_entry.id).one()
)
assert type(pickle.loads(found_entry.value)) == type(PICKLE_VALUE)
assert found_entry.created_by_fk == admin_user.id
def test_get_value(
app_context: AppContext,
key_value_entry: KeyValueEntry,
after_each: None, # noqa: F811
) -> None:
from superset.daos.key_value import KeyValueDAO
value = KeyValueDAO.get_value(
resource=RESOURCE,
key=key_value_entry.id,
codec=JSON_CODEC,
)
assert value == JSON_VALUE
def test_get_id_entry(
app_context: AppContext,
key_value_entry: KeyValueEntry,
after_each: None, # noqa: F811
) -> None:
from superset.daos.key_value import KeyValueDAO
found_entry = KeyValueDAO.get_entry(resource=RESOURCE, key=key_value_entry.id)
assert found_entry is not None
assert found_entry.id == key_value_entry.id
def test_get_uuid_entry(
app_context: AppContext,
key_value_entry: KeyValueEntry, # noqa: F811
after_each: None, # noqa: F811
) -> None:
from superset.daos.key_value import KeyValueDAO
found_entry = KeyValueDAO.get_entry(resource=RESOURCE, key=key_value_entry.uuid)
assert found_entry is not None
assert JSON_CODEC.decode(found_entry.value) == JSON_VALUE
def test_get_id_entry_missing(
app_context: AppContext,
after_each: None, # noqa: F811
) -> None:
from superset.daos.key_value import KeyValueDAO
entry = KeyValueDAO.get_entry(resource=RESOURCE, key=456)
assert entry is None
def test_get_expired_entry(
app_context: AppContext,
after_each: None, # noqa: F811
) -> None:
from superset.daos.key_value import KeyValueDAO
created_entry = KeyValueDAO.create_entry(
resource=RESOURCE,
value=JSON_VALUE,
codec=JSON_CODEC,
key=ID_KEY,
expires_on=datetime.now() - timedelta(days=1),
)
found_entry = KeyValueDAO.get_entry(resource=RESOURCE, key=created_entry.id)
assert found_entry is not None
assert found_entry.is_expired() is True
def test_get_future_expiring_entry(
app_context: AppContext,
after_each: None, # noqa: F811
) -> None:
from superset.daos.key_value import KeyValueDAO
created_entry = KeyValueDAO.create_entry(
resource=RESOURCE,
value=JSON_VALUE,
codec=JSON_CODEC,
key=ID_KEY,
expires_on=datetime.now() + timedelta(days=1),
)
found_entry = KeyValueDAO.get_entry(resource=RESOURCE, key=created_entry.id)
assert found_entry is not None
assert found_entry.is_expired() is False
def test_update_id_entry(
app_context: AppContext,
key_value_entry: KeyValueEntry, # noqa: F811
admin_user: User, # noqa: F811
after_each: None, # noqa: F811
) -> None:
from superset.daos.key_value import KeyValueDAO
with override_user(admin_user):
updated_entry = KeyValueDAO.update_entry(
resource=RESOURCE,
key=ID_KEY,
value=NEW_VALUE,
codec=JSON_CODEC,
)
db.session.flush()
assert updated_entry is not None
assert JSON_CODEC.decode(updated_entry.value) == NEW_VALUE
assert updated_entry.id == ID_KEY
assert updated_entry.uuid == UUID_KEY
found_entry = KeyValueDAO.get_entry(resource=RESOURCE, key=ID_KEY)
assert found_entry is not None
assert JSON_CODEC.decode(found_entry.value) == NEW_VALUE
assert found_entry.changed_by_fk == admin_user.id
def test_update_uuid_entry(
app_context: AppContext,
key_value_entry: KeyValueEntry, # noqa: F811
admin_user: User, # noqa: F811
after_each: None, # noqa: F811
) -> None:
from superset.daos.key_value import KeyValueDAO
with override_user(admin_user):
updated_entry = KeyValueDAO.update_entry(
resource=RESOURCE,
key=UUID_KEY,
value=NEW_VALUE,
codec=JSON_CODEC,
)
db.session.flush()
assert updated_entry is not None
assert JSON_CODEC.decode(updated_entry.value) == NEW_VALUE
assert updated_entry.id == ID_KEY
assert updated_entry.uuid == UUID_KEY
found_entry = KeyValueDAO.get_entry(resource=RESOURCE, key=UUID_KEY)
assert found_entry is not None
assert JSON_CODEC.decode(found_entry.value) == NEW_VALUE
assert found_entry.changed_by_fk == admin_user.id
def test_update_missing_entry(
app_context: AppContext,
admin_user: User, # noqa: F811
after_each: None, # noqa: F811
) -> None:
from superset.daos.key_value import KeyValueDAO
with override_user(admin_user):
with pytest.raises(KeyValueUpdateFailedError):
KeyValueDAO.update_entry(
resource=RESOURCE,
key=456,
value=NEW_VALUE,
codec=JSON_CODEC,
)
def test_upsert_id_entry(
app_context: AppContext,
key_value_entry: KeyValueEntry, # noqa: F811
admin_user: User, # noqa: F811
after_each: None, # noqa: F811
) -> None:
from superset.daos.key_value import KeyValueDAO
with override_user(admin_user):
entry = KeyValueDAO.upsert_entry(
resource=RESOURCE,
key=ID_KEY,
value=NEW_VALUE,
codec=JSON_CODEC,
)
found_entry = KeyValueDAO.get_entry(resource=RESOURCE, key=ID_KEY)
assert found_entry is not None
assert JSON_CODEC.decode(found_entry.value) == NEW_VALUE
assert entry.changed_by_fk == admin_user.id
def test_upsert_uuid_entry(
app_context: AppContext,
key_value_entry: KeyValueEntry, # noqa: F811
admin_user: User, # noqa: F811
after_each: None, # noqa: F811
) -> None:
from superset.daos.key_value import KeyValueDAO
with override_user(admin_user):
entry = KeyValueDAO.upsert_entry(
resource=RESOURCE,
key=UUID_KEY,
value=NEW_VALUE,
codec=JSON_CODEC,
)
db.session.flush()
assert entry is not None
assert entry.id == ID_KEY
assert entry.uuid == UUID_KEY
found_entry = KeyValueDAO.get_entry(resource=RESOURCE, key=UUID_KEY)
assert found_entry is not None
assert JSON_CODEC.decode(found_entry.value) == NEW_VALUE
assert entry.changed_by_fk == admin_user.id
def test_upsert_missing_entry(
app_context: AppContext,
after_each: None, # noqa: F811
) -> None:
from superset.daos.key_value import KeyValueDAO
entry = KeyValueDAO.get_entry(resource=RESOURCE, key=ID_KEY)
assert entry is None
KeyValueDAO.upsert_entry(
resource=RESOURCE,
key=ID_KEY,
value=NEW_VALUE,
codec=JSON_CODEC,
)
entry = KeyValueDAO.get_entry(resource=RESOURCE, key=ID_KEY)
assert entry is not None
assert JSON_CODEC.decode(entry.value) == NEW_VALUE
def test_delete_id_entry(
app_context: AppContext,
key_value_entry: KeyValueEntry,
after_each: None, # noqa: F811
) -> None:
from superset.daos.key_value import KeyValueDAO
assert KeyValueDAO.delete_entry(resource=RESOURCE, key=ID_KEY) is True
def test_delete_uuid_entry(
app_context: AppContext,
key_value_entry: KeyValueEntry,
after_each: None, # noqa: F811
) -> None:
from superset.daos.key_value import KeyValueDAO
assert KeyValueDAO.delete_entry(resource=RESOURCE, key=UUID_KEY) is True
def test_delete_entry_missing(
app_context: AppContext,
after_each: None, # noqa: F811
) -> None:
from superset.daos.key_value import KeyValueDAO
assert KeyValueDAO.delete_entry(resource=RESOURCE, key=12345678) is False

View File

@ -22,17 +22,21 @@ from uuid import UUID
import pytest
from freezegun import freeze_time
from sqlalchemy.orm import Session, sessionmaker
from superset import db
from superset.distributed_lock import KeyValueDistributedLock
from superset.distributed_lock.types import LockValue
from superset.distributed_lock.utils import get_key
from superset.exceptions import CreateKeyValueDistributedLockFailedException
from superset.key_value.types import JsonKeyValueCodec
from superset.utils.lock import get_key, KeyValueDistributedLock
LOCK_VALUE: LockValue = {"value": True}
MAIN_KEY = get_key("ns", a=1, b=2)
OTHER_KEY = get_key("ns2", a=1, b=2)
def _get_lock(key: UUID) -> Any:
def _get_lock(key: UUID, session: Session) -> Any:
from superset.key_value.models import KeyValueEntry
entry = db.session.query(KeyValueEntry).filter_by(uuid=key).first()
@ -42,41 +46,56 @@ def _get_lock(key: UUID) -> Any:
return JsonKeyValueCodec().decode(entry.value)
def _get_other_session() -> Session:
# This session is used to simulate what another worker will find in the metastore
# during the locking process.
from superset import db
bind = db.session.get_bind()
SessionMaker = sessionmaker(bind=bind)
return SessionMaker()
def test_key_value_distributed_lock_happy_path() -> None:
"""
Test successfully acquiring and returning the distributed lock.
Note we use a nested transaction to ensure that the cleanup from the outer context
manager is correctly invoked, otherwise a partial rollback would occur leaving the
database in a fractured state.
Note, we're using another session for asserting the lock state in the Metastore
to simulate what another worker will observe. Otherwise, there's the risk that
the assertions would only be using the non-committed state from the main session.
"""
session = _get_other_session()
with freeze_time("2021-01-01"):
assert _get_lock(MAIN_KEY) is None
assert _get_lock(MAIN_KEY, session) is None
with KeyValueDistributedLock("ns", a=1, b=2) as key:
assert key == MAIN_KEY
assert _get_lock(key) is True
assert _get_lock(OTHER_KEY) is None
assert _get_lock(key, session) == LOCK_VALUE
assert _get_lock(OTHER_KEY, session) is None
with db.session.begin_nested():
with pytest.raises(CreateKeyValueDistributedLockFailedException):
with KeyValueDistributedLock("ns", a=1, b=2):
pass
with pytest.raises(CreateKeyValueDistributedLockFailedException):
with KeyValueDistributedLock("ns", a=1, b=2):
pass
assert _get_lock(MAIN_KEY) is None
assert _get_lock(MAIN_KEY, session) is None
def test_key_value_distributed_lock_expired() -> None:
"""
Test expiration of the distributed lock
Note, we're using another session for asserting the lock state in the Metastore
to simulate what another worker will observe. Otherwise, there's the risk that
the assertions would only be using the non-committed state from the main session.
"""
session = _get_other_session()
with freeze_time("2021-01-01"):
assert _get_lock(MAIN_KEY) is None
assert _get_lock(MAIN_KEY, session) is None
with KeyValueDistributedLock("ns", a=1, b=2):
assert _get_lock(MAIN_KEY) is True
assert _get_lock(MAIN_KEY, session) == LOCK_VALUE
with freeze_time("2022-01-01"):
assert _get_lock(MAIN_KEY) is None
assert _get_lock(MAIN_KEY, session) is None
assert _get_lock(MAIN_KEY) is None
assert _get_lock(MAIN_KEY, session) is None

View File

@ -20,12 +20,15 @@ from __future__ import annotations
import csv
from datetime import datetime
from io import BytesIO, StringIO
from typing import Any
from typing import Any, Generator
import pandas as pd
import pytest
from flask_appbuilder.security.sqla.models import Role, User
from werkzeug.datastructures import FileStorage
from superset import db
@pytest.fixture
def dttm() -> datetime:
@ -73,3 +76,24 @@ def create_columnar_file(
df.to_parquet(buffer, index=False)
buffer.seek(0)
return FileStorage(stream=buffer, filename=filename)
@pytest.fixture
def admin_user() -> Generator[User, None, None]:
role = db.session.query(Role).filter_by(name="Admin").one()
user = User(
first_name="Alice",
last_name="Admin",
email="alice_admin@example.org",
username="alice_admin",
roles=[role],
)
db.session.add(user)
db.session.flush()
yield user
@pytest.fixture
def after_each() -> Generator[None, None, None]:
yield
db.session.rollback()