chore(key-value): convert command to dao (#29344)
This commit is contained in:
parent
0cf676b574
commit
7d6e933348
|
|
@ -19,9 +19,10 @@ from functools import partial
|
|||
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from superset import db
|
||||
from superset.commands.dashboard.permalink.base import BaseDashboardPermalinkCommand
|
||||
from superset.commands.key_value.upsert import UpsertKeyValueCommand
|
||||
from superset.daos.dashboard import DashboardDAO
|
||||
from superset.daos.key_value import KeyValueDAO
|
||||
from superset.dashboards.permalink.exceptions import DashboardPermalinkCreateFailedError
|
||||
from superset.dashboards.permalink.types import DashboardPermalinkState
|
||||
from superset.key_value.exceptions import (
|
||||
|
|
@ -70,14 +71,15 @@ class CreateDashboardPermalinkCommand(BaseDashboardPermalinkCommand):
|
|||
"state": self.state,
|
||||
}
|
||||
user_id = get_user_id()
|
||||
key = UpsertKeyValueCommand(
|
||||
entry = KeyValueDAO.upsert_entry(
|
||||
resource=self.resource,
|
||||
key=get_deterministic_uuid(self.salt, (user_id, value)),
|
||||
value=value,
|
||||
codec=self.codec,
|
||||
).run()
|
||||
assert key.id # for type checks
|
||||
return encode_permalink_key(key=key.id, salt=self.salt)
|
||||
)
|
||||
db.session.flush()
|
||||
assert entry.id # for type checks
|
||||
return encode_permalink_key(key=entry.id, salt=self.salt)
|
||||
|
||||
def validate(self) -> None:
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -21,8 +21,8 @@ from sqlalchemy.exc import SQLAlchemyError
|
|||
|
||||
from superset.commands.dashboard.exceptions import DashboardNotFoundError
|
||||
from superset.commands.dashboard.permalink.base import BaseDashboardPermalinkCommand
|
||||
from superset.commands.key_value.get import GetKeyValueCommand
|
||||
from superset.daos.dashboard import DashboardDAO
|
||||
from superset.daos.key_value import KeyValueDAO
|
||||
from superset.dashboards.permalink.exceptions import DashboardPermalinkGetFailedError
|
||||
from superset.dashboards.permalink.types import DashboardPermalinkValue
|
||||
from superset.key_value.exceptions import (
|
||||
|
|
@ -43,12 +43,7 @@ class GetDashboardPermalinkCommand(BaseDashboardPermalinkCommand):
|
|||
self.validate()
|
||||
try:
|
||||
key = decode_permalink_id(self.key, salt=self.salt)
|
||||
command = GetKeyValueCommand(
|
||||
resource=self.resource,
|
||||
key=key,
|
||||
codec=self.codec,
|
||||
)
|
||||
value: Optional[DashboardPermalinkValue] = command.run()
|
||||
value = KeyValueDAO.get_value(self.resource, key, self.codec)
|
||||
if value:
|
||||
DashboardDAO.get_by_id_or_slug(value["dashboardId"])
|
||||
return value
|
||||
|
|
|
|||
|
|
@ -14,3 +14,28 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from flask import current_app
|
||||
|
||||
from superset.commands.base import BaseCommand
|
||||
from superset.distributed_lock.utils import get_key
|
||||
from superset.key_value.types import JsonKeyValueCodec, KeyValueResource
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
stats_logger = current_app.config["STATS_LOGGER"]
|
||||
|
||||
|
||||
class BaseDistributedLockCommand(BaseCommand):
|
||||
key: uuid.UUID
|
||||
codec = JsonKeyValueCodec()
|
||||
resource = KeyValueResource.LOCK
|
||||
|
||||
def __init__(self, namespace: str, params: dict[str, Any] | None = None):
|
||||
self.key = get_key(namespace, **(params or {}))
|
||||
|
||||
def validate(self) -> None:
|
||||
pass
|
||||
|
|
@ -0,0 +1,64 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from functools import partial
|
||||
|
||||
from flask import current_app
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from superset.commands.distributed_lock.base import BaseDistributedLockCommand
|
||||
from superset.daos.key_value import KeyValueDAO
|
||||
from superset.exceptions import CreateKeyValueDistributedLockFailedException
|
||||
from superset.key_value.exceptions import (
|
||||
KeyValueCodecEncodeException,
|
||||
KeyValueUpsertFailedError,
|
||||
)
|
||||
from superset.key_value.types import KeyValueResource
|
||||
from superset.utils.decorators import on_error, transaction
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
stats_logger = current_app.config["STATS_LOGGER"]
|
||||
|
||||
|
||||
class CreateDistributedLock(BaseDistributedLockCommand):
|
||||
lock_expiration = timedelta(seconds=30)
|
||||
|
||||
def validate(self) -> None:
|
||||
pass
|
||||
|
||||
@transaction(
|
||||
on_error=partial(
|
||||
on_error,
|
||||
catches=(
|
||||
KeyValueCodecEncodeException,
|
||||
KeyValueUpsertFailedError,
|
||||
SQLAlchemyError,
|
||||
),
|
||||
reraise=CreateKeyValueDistributedLockFailedException,
|
||||
),
|
||||
)
|
||||
def run(self) -> None:
|
||||
KeyValueDAO.delete_expired_entries(self.resource)
|
||||
KeyValueDAO.create_entry(
|
||||
resource=KeyValueResource.LOCK,
|
||||
value={"value": True},
|
||||
codec=self.codec,
|
||||
key=self.key,
|
||||
expires_on=datetime.now() + self.lock_expiration,
|
||||
)
|
||||
|
|
@ -14,49 +14,36 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
|
||||
from sqlalchemy import and_
|
||||
from flask import current_app
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from superset import db
|
||||
from superset.commands.base import BaseCommand
|
||||
from superset.commands.distributed_lock.base import BaseDistributedLockCommand
|
||||
from superset.daos.key_value import KeyValueDAO
|
||||
from superset.exceptions import DeleteKeyValueDistributedLockFailedException
|
||||
from superset.key_value.exceptions import KeyValueDeleteFailedError
|
||||
from superset.key_value.models import KeyValueEntry
|
||||
from superset.key_value.types import KeyValueResource
|
||||
from superset.utils.decorators import on_error, transaction
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
stats_logger = current_app.config["STATS_LOGGER"]
|
||||
|
||||
|
||||
class DeleteExpiredKeyValueCommand(BaseCommand):
|
||||
resource: KeyValueResource
|
||||
|
||||
def __init__(self, resource: KeyValueResource):
|
||||
"""
|
||||
Delete all expired key-value pairs
|
||||
|
||||
:param resource: the resource (dashboard, chart etc)
|
||||
:return: was the entry deleted or not
|
||||
"""
|
||||
self.resource = resource
|
||||
|
||||
@transaction(on_error=partial(on_error, reraise=KeyValueDeleteFailedError))
|
||||
def run(self) -> None:
|
||||
self.delete_expired()
|
||||
|
||||
class DeleteDistributedLock(BaseDistributedLockCommand):
|
||||
def validate(self) -> None:
|
||||
pass
|
||||
|
||||
def delete_expired(self) -> None:
|
||||
(
|
||||
db.session.query(KeyValueEntry)
|
||||
.filter(
|
||||
and_(
|
||||
KeyValueEntry.resource == self.resource.value,
|
||||
KeyValueEntry.expires_on <= datetime.now(),
|
||||
)
|
||||
)
|
||||
.delete()
|
||||
)
|
||||
@transaction(
|
||||
on_error=partial(
|
||||
on_error,
|
||||
catches=(
|
||||
KeyValueDeleteFailedError,
|
||||
SQLAlchemyError,
|
||||
),
|
||||
reraise=DeleteKeyValueDistributedLockFailedException,
|
||||
),
|
||||
)
|
||||
def run(self) -> None:
|
||||
KeyValueDAO.delete_entry(self.resource, self.key)
|
||||
|
|
@ -0,0 +1,45 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import cast
|
||||
|
||||
from flask import current_app
|
||||
|
||||
from superset.commands.distributed_lock.base import BaseDistributedLockCommand
|
||||
from superset.daos.key_value import KeyValueDAO
|
||||
from superset.distributed_lock.types import LockValue
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
stats_logger = current_app.config["STATS_LOGGER"]
|
||||
|
||||
|
||||
class GetDistributedLock(BaseDistributedLockCommand):
|
||||
def validate(self) -> None:
|
||||
pass
|
||||
|
||||
def run(self) -> LockValue | None:
|
||||
entry = KeyValueDAO.get_entry(
|
||||
resource=self.resource,
|
||||
key=self.key,
|
||||
)
|
||||
if not entry or entry.is_expired():
|
||||
return None
|
||||
|
||||
return cast(LockValue, self.codec.decode(entry.value))
|
||||
|
|
@ -20,8 +20,9 @@ from typing import Any, Optional
|
|||
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from superset import db
|
||||
from superset.commands.explore.permalink.base import BaseExplorePermalinkCommand
|
||||
from superset.commands.key_value.create import CreateKeyValueCommand
|
||||
from superset.daos.key_value import KeyValueDAO
|
||||
from superset.explore.permalink.exceptions import ExplorePermalinkCreateFailedError
|
||||
from superset.explore.utils import check_access as check_chart_access
|
||||
from superset.key_value.exceptions import (
|
||||
|
|
@ -65,15 +66,12 @@ class CreateExplorePermalinkCommand(BaseExplorePermalinkCommand):
|
|||
"datasource": self.datasource,
|
||||
"state": self.state,
|
||||
}
|
||||
command = CreateKeyValueCommand(
|
||||
resource=self.resource,
|
||||
value=value,
|
||||
codec=self.codec,
|
||||
)
|
||||
key = command.run()
|
||||
if key.id is None:
|
||||
entry = KeyValueDAO.create_entry(self.resource, value, self.codec)
|
||||
db.session.flush()
|
||||
key = entry.id
|
||||
if key is None:
|
||||
raise ExplorePermalinkCreateFailedError("Unexpected missing key id")
|
||||
return encode_permalink_key(key=key.id, salt=self.salt)
|
||||
return encode_permalink_key(key=key, salt=self.salt)
|
||||
|
||||
def validate(self) -> None:
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ from sqlalchemy.exc import SQLAlchemyError
|
|||
|
||||
from superset.commands.dataset.exceptions import DatasetNotFoundError
|
||||
from superset.commands.explore.permalink.base import BaseExplorePermalinkCommand
|
||||
from superset.commands.key_value.get import GetKeyValueCommand
|
||||
from superset.daos.key_value import KeyValueDAO
|
||||
from superset.explore.permalink.exceptions import ExplorePermalinkGetFailedError
|
||||
from superset.explore.permalink.types import ExplorePermalinkValue
|
||||
from superset.explore.utils import check_access as check_chart_access
|
||||
|
|
@ -44,11 +44,7 @@ class GetExplorePermalinkCommand(BaseExplorePermalinkCommand):
|
|||
self.validate()
|
||||
try:
|
||||
key = decode_permalink_id(self.key, salt=self.salt)
|
||||
value: Optional[ExplorePermalinkValue] = GetKeyValueCommand(
|
||||
resource=self.resource,
|
||||
key=key,
|
||||
codec=self.codec,
|
||||
).run()
|
||||
value = KeyValueDAO.get_value(self.resource, key, self.codec)
|
||||
if value:
|
||||
chart_id: Optional[int] = value.get("chartId")
|
||||
# keep this backward compatible for old permalinks
|
||||
|
|
|
|||
|
|
@ -1,102 +0,0 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
from typing import Any, Optional, Union
|
||||
from uuid import UUID
|
||||
|
||||
from superset import db
|
||||
from superset.commands.base import BaseCommand
|
||||
from superset.key_value.exceptions import KeyValueCreateFailedError
|
||||
from superset.key_value.models import KeyValueEntry
|
||||
from superset.key_value.types import Key, KeyValueCodec, KeyValueResource
|
||||
from superset.utils.core import get_user_id
|
||||
from superset.utils.decorators import on_error, transaction
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CreateKeyValueCommand(BaseCommand):
|
||||
resource: KeyValueResource
|
||||
value: Any
|
||||
codec: KeyValueCodec
|
||||
key: Optional[Union[int, UUID]]
|
||||
expires_on: Optional[datetime]
|
||||
|
||||
def __init__( # pylint: disable=too-many-arguments
|
||||
self,
|
||||
resource: KeyValueResource,
|
||||
value: Any,
|
||||
codec: KeyValueCodec,
|
||||
key: Optional[Union[int, UUID]] = None,
|
||||
expires_on: Optional[datetime] = None,
|
||||
):
|
||||
"""
|
||||
Create a new key-value pair
|
||||
|
||||
:param resource: the resource (dashboard, chart etc)
|
||||
:param value: the value to persist in the key-value store
|
||||
:param codec: codec used to encode the value
|
||||
:param key: id of entry (autogenerated if undefined)
|
||||
:param expires_on: entry expiration time
|
||||
:
|
||||
"""
|
||||
self.resource = resource
|
||||
self.value = value
|
||||
self.codec = codec
|
||||
self.key = key
|
||||
self.expires_on = expires_on
|
||||
|
||||
@transaction(on_error=partial(on_error, reraise=KeyValueCreateFailedError))
|
||||
def run(self) -> Key:
|
||||
"""
|
||||
Persist the value
|
||||
|
||||
:return: the key associated with the persisted value
|
||||
|
||||
"""
|
||||
|
||||
return self.create()
|
||||
|
||||
def validate(self) -> None:
|
||||
pass
|
||||
|
||||
def create(self) -> Key:
|
||||
try:
|
||||
value = self.codec.encode(self.value)
|
||||
except Exception as ex:
|
||||
raise KeyValueCreateFailedError("Unable to encode value") from ex
|
||||
entry = KeyValueEntry(
|
||||
resource=self.resource.value,
|
||||
value=value,
|
||||
created_on=datetime.now(),
|
||||
created_by_fk=get_user_id(),
|
||||
expires_on=self.expires_on,
|
||||
)
|
||||
if self.key is not None:
|
||||
try:
|
||||
if isinstance(self.key, UUID):
|
||||
entry.uuid = self.key
|
||||
else:
|
||||
entry.id = self.key
|
||||
except ValueError as ex:
|
||||
raise KeyValueCreateFailedError() from ex
|
||||
|
||||
db.session.add(entry)
|
||||
db.session.flush()
|
||||
return Key(id=entry.id, uuid=entry.uuid)
|
||||
|
|
@ -1,63 +0,0 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import logging
|
||||
from functools import partial
|
||||
from typing import Union
|
||||
from uuid import UUID
|
||||
|
||||
from superset import db
|
||||
from superset.commands.base import BaseCommand
|
||||
from superset.key_value.exceptions import KeyValueDeleteFailedError
|
||||
from superset.key_value.models import KeyValueEntry
|
||||
from superset.key_value.types import KeyValueResource
|
||||
from superset.key_value.utils import get_filter
|
||||
from superset.utils.decorators import on_error, transaction
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DeleteKeyValueCommand(BaseCommand):
|
||||
key: Union[int, UUID]
|
||||
resource: KeyValueResource
|
||||
|
||||
def __init__(self, resource: KeyValueResource, key: Union[int, UUID]):
|
||||
"""
|
||||
Delete a key-value pair
|
||||
|
||||
:param resource: the resource (dashboard, chart etc)
|
||||
:param key: the key to delete
|
||||
:return: was the entry deleted or not
|
||||
"""
|
||||
self.resource = resource
|
||||
self.key = key
|
||||
|
||||
@transaction(on_error=partial(on_error, reraise=KeyValueDeleteFailedError))
|
||||
def run(self) -> bool:
|
||||
return self.delete()
|
||||
|
||||
def validate(self) -> None:
|
||||
pass
|
||||
|
||||
def delete(self) -> bool:
|
||||
if (
|
||||
entry := db.session.query(KeyValueEntry)
|
||||
.filter_by(**get_filter(self.resource, self.key))
|
||||
.first()
|
||||
):
|
||||
db.session.delete(entry)
|
||||
return True
|
||||
return False
|
||||
|
|
@ -1,71 +0,0 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import logging
|
||||
from typing import Any, Optional, Union
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from superset import db
|
||||
from superset.commands.base import BaseCommand
|
||||
from superset.key_value.exceptions import KeyValueGetFailedError
|
||||
from superset.key_value.models import KeyValueEntry
|
||||
from superset.key_value.types import KeyValueCodec, KeyValueResource
|
||||
from superset.key_value.utils import get_filter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GetKeyValueCommand(BaseCommand):
|
||||
resource: KeyValueResource
|
||||
key: Union[int, UUID]
|
||||
codec: KeyValueCodec
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
resource: KeyValueResource,
|
||||
key: Union[int, UUID],
|
||||
codec: KeyValueCodec,
|
||||
):
|
||||
"""
|
||||
Retrieve a key value entry
|
||||
|
||||
:param resource: the resource (dashboard, chart etc)
|
||||
:param key: the key to retrieve
|
||||
:param codec: codec used to decode the value
|
||||
:return: the value associated with the key if present
|
||||
"""
|
||||
self.resource = resource
|
||||
self.key = key
|
||||
self.codec = codec
|
||||
|
||||
def run(self) -> Any:
|
||||
try:
|
||||
return self.get()
|
||||
except SQLAlchemyError as ex:
|
||||
raise KeyValueGetFailedError() from ex
|
||||
|
||||
def validate(self) -> None:
|
||||
pass
|
||||
|
||||
def get(self) -> Optional[Any]:
|
||||
filter_ = get_filter(self.resource, self.key)
|
||||
entry = db.session.query(KeyValueEntry).filter_by(**filter_).first()
|
||||
if entry and not entry.is_expired():
|
||||
return self.codec.decode(entry.value)
|
||||
return None
|
||||
|
|
@ -1,87 +0,0 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
from typing import Any, Optional, Union
|
||||
from uuid import UUID
|
||||
|
||||
from superset import db
|
||||
from superset.commands.base import BaseCommand
|
||||
from superset.key_value.exceptions import KeyValueUpdateFailedError
|
||||
from superset.key_value.models import KeyValueEntry
|
||||
from superset.key_value.types import Key, KeyValueCodec, KeyValueResource
|
||||
from superset.key_value.utils import get_filter
|
||||
from superset.utils.core import get_user_id
|
||||
from superset.utils.decorators import on_error, transaction
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UpdateKeyValueCommand(BaseCommand):
|
||||
resource: KeyValueResource
|
||||
value: Any
|
||||
codec: KeyValueCodec
|
||||
key: Union[int, UUID]
|
||||
expires_on: Optional[datetime]
|
||||
|
||||
def __init__( # pylint: disable=too-many-arguments
|
||||
self,
|
||||
resource: KeyValueResource,
|
||||
key: Union[int, UUID],
|
||||
value: Any,
|
||||
codec: KeyValueCodec,
|
||||
expires_on: Optional[datetime] = None,
|
||||
):
|
||||
"""
|
||||
Update a key value entry
|
||||
|
||||
:param resource: the resource (dashboard, chart etc)
|
||||
:param key: the key to update
|
||||
:param value: the value to persist in the key-value store
|
||||
:param codec: codec used to encode the value
|
||||
:param expires_on: entry expiration time
|
||||
:return: the key associated with the updated value
|
||||
"""
|
||||
self.resource = resource
|
||||
self.key = key
|
||||
self.value = value
|
||||
self.codec = codec
|
||||
self.expires_on = expires_on
|
||||
|
||||
@transaction(on_error=partial(on_error, reraise=KeyValueUpdateFailedError))
|
||||
def run(self) -> Optional[Key]:
|
||||
return self.update()
|
||||
|
||||
def validate(self) -> None:
|
||||
pass
|
||||
|
||||
def update(self) -> Optional[Key]:
|
||||
filter_ = get_filter(self.resource, self.key)
|
||||
entry: KeyValueEntry = (
|
||||
db.session.query(KeyValueEntry).filter_by(**filter_).first()
|
||||
)
|
||||
if entry:
|
||||
entry.value = self.codec.encode(self.value)
|
||||
entry.expires_on = self.expires_on
|
||||
entry.changed_on = datetime.now()
|
||||
entry.changed_by_fk = get_user_id()
|
||||
db.session.flush()
|
||||
return Key(id=entry.id, uuid=entry.uuid)
|
||||
|
||||
return None
|
||||
|
|
@ -1,104 +0,0 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
from typing import Any, Optional, Union
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from superset import db
|
||||
from superset.commands.base import BaseCommand
|
||||
from superset.commands.key_value.create import CreateKeyValueCommand
|
||||
from superset.key_value.exceptions import (
|
||||
KeyValueCreateFailedError,
|
||||
KeyValueUpsertFailedError,
|
||||
)
|
||||
from superset.key_value.models import KeyValueEntry
|
||||
from superset.key_value.types import Key, KeyValueCodec, KeyValueResource
|
||||
from superset.key_value.utils import get_filter
|
||||
from superset.utils.core import get_user_id
|
||||
from superset.utils.decorators import on_error, transaction
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UpsertKeyValueCommand(BaseCommand):
|
||||
resource: KeyValueResource
|
||||
value: Any
|
||||
key: Union[int, UUID]
|
||||
codec: KeyValueCodec
|
||||
expires_on: Optional[datetime]
|
||||
|
||||
def __init__( # pylint: disable=too-many-arguments
|
||||
self,
|
||||
resource: KeyValueResource,
|
||||
key: Union[int, UUID],
|
||||
value: Any,
|
||||
codec: KeyValueCodec,
|
||||
expires_on: Optional[datetime] = None,
|
||||
):
|
||||
"""
|
||||
Upsert a key value entry
|
||||
|
||||
:param resource: the resource (dashboard, chart etc)
|
||||
:param key: the key to update
|
||||
:param value: the value to persist in the key-value store
|
||||
:param codec: codec used to encode the value
|
||||
:param expires_on: entry expiration time
|
||||
:return: the key associated with the updated value
|
||||
"""
|
||||
self.resource = resource
|
||||
self.key = key
|
||||
self.value = value
|
||||
self.codec = codec
|
||||
self.expires_on = expires_on
|
||||
|
||||
@transaction(
|
||||
on_error=partial(
|
||||
on_error,
|
||||
catches=(KeyValueCreateFailedError, SQLAlchemyError),
|
||||
reraise=KeyValueUpsertFailedError,
|
||||
),
|
||||
)
|
||||
def run(self) -> Key:
|
||||
return self.upsert()
|
||||
|
||||
def validate(self) -> None:
|
||||
pass
|
||||
|
||||
def upsert(self) -> Key:
|
||||
if (
|
||||
entry := db.session.query(KeyValueEntry)
|
||||
.filter_by(**get_filter(self.resource, self.key))
|
||||
.first()
|
||||
):
|
||||
entry.value = self.codec.encode(self.value)
|
||||
entry.expires_on = self.expires_on
|
||||
entry.changed_on = datetime.now()
|
||||
entry.changed_by_fk = get_user_id()
|
||||
return Key(entry.id, entry.uuid)
|
||||
|
||||
return CreateKeyValueCommand(
|
||||
resource=self.resource,
|
||||
value=self.value,
|
||||
codec=self.codec,
|
||||
key=self.key,
|
||||
expires_on=self.expires_on,
|
||||
).run()
|
||||
|
|
@ -0,0 +1,145 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_
|
||||
|
||||
from superset import db
|
||||
from superset.daos.base import BaseDAO
|
||||
from superset.key_value.exceptions import (
|
||||
KeyValueCreateFailedError,
|
||||
KeyValueUpdateFailedError,
|
||||
)
|
||||
from superset.key_value.models import KeyValueEntry
|
||||
from superset.key_value.types import Key, KeyValueCodec, KeyValueResource
|
||||
from superset.key_value.utils import get_filter
|
||||
from superset.utils.core import get_user_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class KeyValueDAO(BaseDAO[KeyValueEntry]):
    """DAO for generic key-value entries (permalinks, metastore cache, locks)."""

    @staticmethod
    def get_entry(
        resource: KeyValueResource,
        key: Key,
    ) -> KeyValueEntry | None:
        """
        Return the entry for ``key`` within ``resource``, or ``None`` if absent.

        Expired entries are still returned here; expiry is only enforced by
        ``get_value``.
        """
        filter_ = get_filter(resource, key)
        return db.session.query(KeyValueEntry).filter_by(**filter_).first()

    @classmethod
    def get_value(
        cls,
        resource: KeyValueResource,
        key: Key,
        codec: KeyValueCodec,
    ) -> Any:
        """
        Return the decoded value for ``key``, or ``None`` if missing or expired.
        """
        entry = cls.get_entry(resource, key)
        if not entry or entry.is_expired():
            return None

        return codec.decode(entry.value)

    @staticmethod
    def delete_entry(resource: KeyValueResource, key: Key) -> bool:
        """
        Delete the entry for ``key``; return ``True`` if an entry existed.
        """
        if entry := KeyValueDAO.get_entry(resource, key):
            db.session.delete(entry)
            return True

        return False

    @staticmethod
    def delete_expired_entries(resource: KeyValueResource) -> None:
        """
        Bulk-delete all entries in ``resource`` whose expiry time has passed.
        """
        # NOTE: naive local time is used consistently with how expires_on is
        # written in create_entry/_write_value.
        (
            db.session.query(KeyValueEntry)
            .filter(
                and_(
                    KeyValueEntry.resource == resource.value,
                    KeyValueEntry.expires_on <= datetime.now(),
                )
            )
            .delete()
        )

    @staticmethod
    def create_entry(
        resource: KeyValueResource,
        value: Any,
        codec: KeyValueCodec,
        key: Key | None = None,
        expires_on: datetime | None = None,
    ) -> KeyValueEntry:
        """
        Create (add to the session) a new entry holding the encoded ``value``.

        :param key: optional explicit key; an ``UUID`` is stored as the entry's
            uuid, an ``int`` as its id. When omitted, the database assigns ids.
        :raises KeyValueCreateFailedError: if the codec cannot encode ``value``
            or the explicit key is rejected.
        """
        try:
            encoded_value = codec.encode(value)
        # codecs may raise arbitrary errors (json, pickle, marshmallow), so a
        # broad catch is deliberate here
        except Exception as ex:
            raise KeyValueCreateFailedError("Unable to encode value") from ex
        entry = KeyValueEntry(
            resource=resource.value,
            value=encoded_value,
            created_on=datetime.now(),
            created_by_fk=get_user_id(),
            expires_on=expires_on,
        )
        if key is not None:
            try:
                if isinstance(key, UUID):
                    entry.uuid = key
                else:
                    entry.id = key
            except ValueError as ex:
                raise KeyValueCreateFailedError() from ex
        db.session.add(entry)
        return entry

    @staticmethod
    def _write_value(
        entry: KeyValueEntry,
        value: Any,
        codec: KeyValueCodec,
        expires_on: datetime | None,
    ) -> KeyValueEntry:
        # Shared by upsert_entry/update_entry: encode the new value and stamp
        # the audit columns.
        entry.value = codec.encode(value)
        entry.expires_on = expires_on
        entry.changed_on = datetime.now()
        entry.changed_by_fk = get_user_id()
        return entry

    @staticmethod
    def upsert_entry(
        resource: KeyValueResource,
        value: Any,
        codec: KeyValueCodec,
        key: Key,
        expires_on: datetime | None = None,
    ) -> KeyValueEntry:
        """
        Update the entry for ``key`` if it exists, otherwise create it.
        """
        if entry := KeyValueDAO.get_entry(resource, key):
            return KeyValueDAO._write_value(entry, value, codec, expires_on)

        return KeyValueDAO.create_entry(resource, value, codec, key, expires_on)

    @staticmethod
    def update_entry(
        resource: KeyValueResource,
        value: Any,
        codec: KeyValueCodec,
        key: Key,
        expires_on: datetime | None = None,
    ) -> KeyValueEntry:
        """
        Update the existing entry for ``key``.

        :raises KeyValueUpdateFailedError: if no entry exists for ``key``
        """
        if entry := KeyValueDAO.get_entry(resource, key):
            return KeyValueDAO._write_value(entry, value, codec, expires_on)

        raise KeyValueUpdateFailedError()
|
||||
|
|
@ -21,40 +21,18 @@ import logging
|
|||
import uuid
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, cast, TypeVar, Union
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
|
||||
from superset.distributed_lock.utils import get_key
|
||||
from superset.exceptions import CreateKeyValueDistributedLockFailedException
|
||||
from superset.key_value.exceptions import KeyValueCreateFailedError
|
||||
from superset.key_value.types import JsonKeyValueCodec, KeyValueResource
|
||||
from superset.utils import json
|
||||
|
||||
LOCK_EXPIRATION = timedelta(seconds=30)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def serialize(params: dict[str, Any]) -> str:
|
||||
"""
|
||||
Serialize parameters into a string.
|
||||
"""
|
||||
|
||||
T = TypeVar(
|
||||
"T",
|
||||
bound=Union[dict[str, Any], list[Any], int, float, str, bool, None],
|
||||
)
|
||||
|
||||
def sort(obj: T) -> T:
|
||||
if isinstance(obj, dict):
|
||||
return cast(T, {k: sort(v) for k, v in sorted(obj.items())})
|
||||
if isinstance(obj, list):
|
||||
return cast(T, [sort(x) for x in obj])
|
||||
return obj
|
||||
|
||||
return json.dumps(params)
|
||||
|
||||
|
||||
def get_key(namespace: str, **kwargs: Any) -> uuid.UUID:
|
||||
return uuid.uuid5(uuid.uuid5(uuid.NAMESPACE_DNS, namespace), serialize(kwargs))
|
||||
CODEC = JsonKeyValueCodec()
|
||||
LOCK_EXPIRATION = timedelta(seconds=30)
|
||||
RESOURCE = KeyValueResource.LOCK
|
||||
|
||||
|
||||
@contextmanager
|
||||
|
|
@ -75,28 +53,25 @@ def KeyValueDistributedLock( # pylint: disable=invalid-name
|
|||
:yields: A unique identifier (UUID) for the acquired lock (the KV key).
|
||||
:raises CreateKeyValueDistributedLockFailedException: If the lock is taken.
|
||||
"""
|
||||
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from superset.commands.key_value.create import CreateKeyValueCommand
|
||||
from superset.commands.key_value.delete import DeleteKeyValueCommand
|
||||
from superset.commands.key_value.delete_expired import DeleteExpiredKeyValueCommand
|
||||
from superset.commands.distributed_lock.create import CreateDistributedLock
|
||||
from superset.commands.distributed_lock.delete import DeleteDistributedLock
|
||||
from superset.commands.distributed_lock.get import GetDistributedLock
|
||||
|
||||
key = get_key(namespace, **kwargs)
|
||||
value = GetDistributedLock(namespace=namespace, params=kwargs).run()
|
||||
if value:
|
||||
logger.debug("Lock on namespace %s for key %s already taken", namespace, key)
|
||||
raise CreateKeyValueDistributedLockFailedException("Lock already taken")
|
||||
|
||||
logger.debug("Acquiring lock on namespace %s for key %s", namespace, key)
|
||||
try:
|
||||
DeleteExpiredKeyValueCommand(resource=KeyValueResource.LOCK).run()
|
||||
CreateKeyValueCommand(
|
||||
resource=KeyValueResource.LOCK,
|
||||
codec=JsonKeyValueCodec(),
|
||||
key=key,
|
||||
value=True,
|
||||
expires_on=datetime.now() + LOCK_EXPIRATION,
|
||||
).run()
|
||||
CreateDistributedLock(namespace=namespace, params=kwargs).run()
|
||||
except CreateKeyValueDistributedLockFailedException as ex:
|
||||
logger.debug("Lock on namespace %s for key %s already taken", namespace, key)
|
||||
raise CreateKeyValueDistributedLockFailedException("Lock already taken") from ex
|
||||
|
||||
yield key
|
||||
|
||||
DeleteKeyValueCommand(resource=KeyValueResource.LOCK, key=key).run()
|
||||
logger.debug("Removed lock on namespace %s for key %s", namespace, key)
|
||||
except KeyValueCreateFailedError as ex:
|
||||
raise CreateKeyValueDistributedLockFailedException(
|
||||
"Error acquiring lock"
|
||||
) from ex
|
||||
yield key
|
||||
DeleteDistributedLock(namespace=namespace, params=kwargs).run()
|
||||
logger.debug("Removed lock on namespace %s for key %s", namespace, key)
|
||||
|
|
@ -14,3 +14,8 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from typing import TypedDict
|
||||
|
||||
|
||||
class LockValue(TypedDict):
    """Payload stored in the key-value store for a held distributed lock."""

    value: bool
|
||||
|
|
@ -14,3 +14,32 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import uuid
|
||||
from typing import Any, cast, TypeVar, Union
|
||||
|
||||
from superset.utils import json
|
||||
|
||||
|
||||
def serialize(params: dict[str, Any]) -> str:
    """
    Serialize parameters into a deterministic string.

    Dict keys are sorted recursively before dumping, so logically equal
    parameter sets always serialize to the same string (and therefore hash to
    the same lock key) regardless of key insertion order. List order is
    preserved, as it is significant.
    """

    T = TypeVar(
        "T",
        bound=Union[dict[str, Any], list[Any], int, float, str, bool, None],
    )

    def sort(obj: T) -> T:
        if isinstance(obj, dict):
            return cast(T, {k: sort(v) for k, v in sorted(obj.items())})
        if isinstance(obj, list):
            return cast(T, [sort(x) for x in obj])
        return obj

    # Bug fix: `sort` was defined but never applied, so identical kwargs passed
    # in a different order produced different strings and hence different keys.
    return json.dumps(sort(params))
|
||||
|
||||
|
||||
def get_key(namespace: str, **kwargs: Any) -> uuid.UUID:
    """Build a deterministic UUID lock key from a namespace and parameters."""
    namespace_uuid = uuid.uuid5(uuid.NAMESPACE_DNS, namespace)
    return uuid.uuid5(namespace_uuid, serialize(kwargs))
|
||||
|
|
@ -379,6 +379,12 @@ class CreateKeyValueDistributedLockFailedException(Exception):
|
|||
"""
|
||||
|
||||
|
||||
class DeleteKeyValueDistributedLockFailedException(Exception):
    """Raised when a distributed lock entry could not be deleted."""
|
||||
|
||||
|
||||
class DatabaseNotFoundException(SupersetErrorException):
|
||||
status = 404
|
||||
|
||||
|
|
|
|||
|
|
@ -21,7 +21,10 @@ from uuid import UUID, uuid3
|
|||
|
||||
from flask import current_app, Flask, has_app_context
|
||||
from flask_caching import BaseCache
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from superset import db
|
||||
from superset.daos.key_value import KeyValueDAO
|
||||
from superset.key_value.exceptions import KeyValueCreateFailedError
|
||||
from superset.key_value.types import (
|
||||
KeyValueCodec,
|
||||
|
|
@ -29,6 +32,7 @@ from superset.key_value.types import (
|
|||
PickleKeyValueCodec,
|
||||
)
|
||||
from superset.key_value.utils import get_uuid_namespace
|
||||
from superset.utils.decorators import transaction
|
||||
|
||||
RESOURCE = KeyValueResource.METASTORE_CACHE
|
||||
|
||||
|
|
@ -68,15 +72,6 @@ class SupersetMetastoreCache(BaseCache):
|
|||
def get_key(self, key: str) -> UUID:
|
||||
return uuid3(self.namespace, key)
|
||||
|
||||
@staticmethod
|
||||
def _prune() -> None:
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from superset.commands.key_value.delete_expired import (
|
||||
DeleteExpiredKeyValueCommand,
|
||||
)
|
||||
|
||||
DeleteExpiredKeyValueCommand(resource=RESOURCE).run()
|
||||
|
||||
def _get_expiry(self, timeout: Optional[int]) -> Optional[datetime]:
|
||||
timeout = self._normalize_timeout(timeout)
|
||||
if timeout is not None and timeout > 0:
|
||||
|
|
@ -84,44 +79,34 @@ class SupersetMetastoreCache(BaseCache):
|
|||
return None
|
||||
|
||||
def set(self, key: str, value: Any, timeout: Optional[int] = None) -> bool:
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from superset.commands.key_value.upsert import UpsertKeyValueCommand
|
||||
|
||||
UpsertKeyValueCommand(
|
||||
KeyValueDAO.upsert_entry(
|
||||
resource=RESOURCE,
|
||||
key=self.get_key(key),
|
||||
value=value,
|
||||
codec=self.codec,
|
||||
expires_on=self._get_expiry(timeout),
|
||||
).run()
|
||||
)
|
||||
db.session.commit() # pylint: disable=consider-using-transaction
|
||||
return True
|
||||
|
||||
def add(self, key: str, value: Any, timeout: Optional[int] = None) -> bool:
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from superset.commands.key_value.create import CreateKeyValueCommand
|
||||
|
||||
try:
|
||||
self._prune()
|
||||
CreateKeyValueCommand(
|
||||
KeyValueDAO.delete_expired_entries(RESOURCE)
|
||||
KeyValueDAO.create_entry(
|
||||
resource=RESOURCE,
|
||||
value=value,
|
||||
codec=self.codec,
|
||||
key=self.get_key(key),
|
||||
expires_on=self._get_expiry(timeout),
|
||||
).run()
|
||||
)
|
||||
db.session.commit() # pylint: disable=consider-using-transaction
|
||||
return True
|
||||
except KeyValueCreateFailedError:
|
||||
except (SQLAlchemyError, KeyValueCreateFailedError):
|
||||
db.session.rollback() # pylint: disable=consider-using-transaction
|
||||
return False
|
||||
|
||||
def get(self, key: str) -> Any:
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from superset.commands.key_value.get import GetKeyValueCommand
|
||||
|
||||
return GetKeyValueCommand(
|
||||
resource=RESOURCE,
|
||||
key=self.get_key(key),
|
||||
codec=self.codec,
|
||||
).run()
|
||||
return KeyValueDAO.get_value(RESOURCE, self.get_key(key), self.codec)
|
||||
|
||||
def has(self, key: str) -> bool:
|
||||
entry = self.get(key)
|
||||
|
|
@ -129,8 +114,6 @@ class SupersetMetastoreCache(BaseCache):
|
|||
return True
|
||||
return False
|
||||
|
||||
@transaction()
|
||||
def delete(self, key: str) -> Any:
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from superset.commands.key_value.delete import DeleteKeyValueCommand
|
||||
|
||||
return DeleteKeyValueCommand(resource=RESOURCE, key=self.get_key(key)).run()
|
||||
return KeyValueDAO.delete_entry(RESOURCE, self.get_key(key))
|
||||
|
|
|
|||
|
|
@ -18,8 +18,10 @@
|
|||
from typing import Any, Optional
|
||||
from uuid import uuid3
|
||||
|
||||
from superset.daos.key_value import KeyValueDAO
|
||||
from superset.key_value.types import JsonKeyValueCodec, KeyValueResource, SharedKey
|
||||
from superset.key_value.utils import get_uuid_namespace, random_key
|
||||
from superset.utils.decorators import transaction
|
||||
|
||||
RESOURCE = KeyValueResource.APP
|
||||
NAMESPACE = get_uuid_namespace("")
|
||||
|
|
@ -27,24 +29,14 @@ CODEC = JsonKeyValueCodec()
|
|||
|
||||
|
||||
def get_shared_value(key: SharedKey) -> Optional[Any]:
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from superset.commands.key_value.get import GetKeyValueCommand
|
||||
|
||||
uuid_key = uuid3(NAMESPACE, key)
|
||||
return GetKeyValueCommand(RESOURCE, key=uuid_key, codec=CODEC).run()
|
||||
return KeyValueDAO.get_value(RESOURCE, uuid_key, CODEC)
|
||||
|
||||
|
||||
@transaction()
|
||||
def set_shared_value(key: SharedKey, value: Any) -> None:
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from superset.commands.key_value.create import CreateKeyValueCommand
|
||||
|
||||
uuid_key = uuid3(NAMESPACE, key)
|
||||
CreateKeyValueCommand(
|
||||
resource=RESOURCE,
|
||||
value=value,
|
||||
key=uuid_key,
|
||||
codec=CODEC,
|
||||
).run()
|
||||
KeyValueDAO.create_entry(RESOURCE, value, CODEC, uuid_key)
|
||||
|
||||
|
||||
def get_permalink_salt(key: SharedKey) -> str:
|
||||
|
|
|
|||
|
|
@ -19,8 +19,7 @@ from __future__ import annotations
|
|||
import json
|
||||
import pickle
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, TypedDict
|
||||
from typing import Any, TypedDict, Union
|
||||
from uuid import UUID
|
||||
|
||||
from marshmallow import Schema, ValidationError
|
||||
|
|
@ -31,11 +30,7 @@ from superset.key_value.exceptions import (
|
|||
)
|
||||
from superset.utils.backports import StrEnum
|
||||
|
||||
|
||||
@dataclass
|
||||
class Key:
|
||||
id: int | None
|
||||
uuid: UUID | None
|
||||
Key = Union[int, UUID]
|
||||
|
||||
|
||||
class KeyValueFilter(TypedDict, total=False):
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ import hashids
|
|||
from flask_babel import gettext as _
|
||||
|
||||
from superset.key_value.exceptions import KeyValueParseKeyError
|
||||
from superset.key_value.types import KeyValueFilter, KeyValueResource
|
||||
from superset.key_value.types import Key, KeyValueFilter, KeyValueResource
|
||||
from superset.utils.json import json_dumps_w_dates
|
||||
|
||||
HASHIDS_MIN_LENGTH = 11
|
||||
|
|
@ -35,7 +35,7 @@ def random_key() -> str:
|
|||
return token_urlsafe(48)
|
||||
|
||||
|
||||
def get_filter(resource: KeyValueResource, key: int | UUID) -> KeyValueFilter:
|
||||
def get_filter(resource: KeyValueResource, key: Key) -> KeyValueFilter:
|
||||
try:
|
||||
filter_: KeyValueFilter = {"resource": resource.value}
|
||||
if isinstance(key, UUID):
|
||||
|
|
|
|||
|
|
@ -26,9 +26,9 @@ from flask import current_app, url_for
|
|||
from marshmallow import EXCLUDE, fields, post_load, Schema
|
||||
|
||||
from superset import db
|
||||
from superset.distributed_lock import KeyValueDistributedLock
|
||||
from superset.exceptions import CreateKeyValueDistributedLockFailedException
|
||||
from superset.superset_typing import OAuth2ClientConfig, OAuth2State
|
||||
from superset.utils.lock import KeyValueDistributedLock
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
|
|
|
|||
|
|
@ -133,11 +133,11 @@ class TestCreatePermalinkDataCommand(SupersetTestCase):
|
|||
assert cache_data.get("datasource") == datasource
|
||||
|
||||
@patch("superset.security.manager.g")
|
||||
@patch("superset.commands.key_value.get.GetKeyValueCommand.run")
|
||||
@patch("superset.daos.key_value.KeyValueDAO.get_value")
|
||||
@patch("superset.commands.explore.permalink.get.decode_permalink_id")
|
||||
@pytest.mark.usefixtures("create_dataset", "create_slice")
|
||||
def test_get_permalink_command_with_old_dataset_key(
|
||||
self, decode_id_mock, get_kv_command_mock, mock_g
|
||||
self, decode_id_mock, kv_get_value_mock, mock_g
|
||||
):
|
||||
mock_g.user = security_manager.find_user("admin")
|
||||
app.config["EXPLORE_FORM_DATA_CACHE_CONFIG"] = {
|
||||
|
|
@ -149,13 +149,14 @@ class TestCreatePermalinkDataCommand(SupersetTestCase):
|
|||
)
|
||||
slice = db.session.query(Slice).filter_by(slice_name="slice_name").first()
|
||||
|
||||
datasource_string = f"{dataset.id}__{DatasourceType.TABLE}"
|
||||
datasource_string = f"{dataset.id}__{DatasourceType.TABLE.value}"
|
||||
|
||||
decode_id_mock.return_value = "123456"
|
||||
get_kv_command_mock.return_value = {
|
||||
kv_get_value_mock.return_value = {
|
||||
"chartId": slice.id,
|
||||
"datasetId": dataset.id,
|
||||
"datasource": datasource_string,
|
||||
"datasourceType": DatasourceType.TABLE.value,
|
||||
"state": {
|
||||
"formData": {"datasource": datasource_string, "slice_id": slice.id}
|
||||
},
|
||||
|
|
|
|||
|
|
@ -60,6 +60,7 @@ def test_caching_flow(app_context: AppContext, cache: SupersetMetastoreCache) ->
|
|||
assert cache.has(FIRST_KEY) is False
|
||||
assert cache.add(FIRST_KEY, FIRST_KEY_INITIAL_VALUE) is True
|
||||
assert cache.has(FIRST_KEY) is True
|
||||
assert cache.get(FIRST_KEY) == FIRST_KEY_INITIAL_VALUE
|
||||
cache.set(SECOND_KEY, SECOND_VALUE)
|
||||
assert cache.get(FIRST_KEY) == FIRST_KEY_INITIAL_VALUE
|
||||
assert cache.get(SECOND_KEY) == SECOND_VALUE
|
||||
|
|
|
|||
|
|
@ -1,96 +0,0 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import pickle
|
||||
|
||||
import pytest
|
||||
from flask.ctx import AppContext
|
||||
from flask_appbuilder.security.sqla.models import User
|
||||
|
||||
from superset.extensions import db
|
||||
from superset.key_value.exceptions import KeyValueCreateFailedError
|
||||
from superset.utils import json
|
||||
from superset.utils.core import override_user
|
||||
from tests.integration_tests.key_value.commands.fixtures import (
|
||||
admin, # noqa: F401
|
||||
JSON_CODEC,
|
||||
JSON_VALUE,
|
||||
PICKLE_CODEC,
|
||||
PICKLE_VALUE,
|
||||
RESOURCE,
|
||||
)
|
||||
|
||||
|
||||
def test_create_id_entry(app_context: AppContext, admin: User) -> None: # noqa: F811
|
||||
from superset.commands.key_value.create import CreateKeyValueCommand
|
||||
from superset.key_value.models import KeyValueEntry
|
||||
|
||||
with override_user(admin):
|
||||
key = CreateKeyValueCommand(
|
||||
resource=RESOURCE,
|
||||
value=JSON_VALUE,
|
||||
codec=JSON_CODEC,
|
||||
).run()
|
||||
entry = db.session.query(KeyValueEntry).filter_by(id=key.id).one()
|
||||
assert json.loads(entry.value) == JSON_VALUE
|
||||
assert entry.created_by_fk == admin.id
|
||||
db.session.delete(entry)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_create_uuid_entry(app_context: AppContext, admin: User) -> None: # noqa: F811
|
||||
from superset.commands.key_value.create import CreateKeyValueCommand
|
||||
from superset.key_value.models import KeyValueEntry
|
||||
|
||||
with override_user(admin):
|
||||
key = CreateKeyValueCommand(
|
||||
resource=RESOURCE, value=JSON_VALUE, codec=JSON_CODEC
|
||||
).run()
|
||||
entry = db.session.query(KeyValueEntry).filter_by(uuid=key.uuid).one()
|
||||
assert json.loads(entry.value) == JSON_VALUE
|
||||
assert entry.created_by_fk == admin.id
|
||||
db.session.delete(entry)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_create_fail_json_entry(app_context: AppContext, admin: User) -> None: # noqa: F811
|
||||
from superset.commands.key_value.create import CreateKeyValueCommand
|
||||
|
||||
with pytest.raises(KeyValueCreateFailedError):
|
||||
CreateKeyValueCommand(
|
||||
resource=RESOURCE,
|
||||
value=PICKLE_VALUE,
|
||||
codec=JSON_CODEC,
|
||||
).run()
|
||||
|
||||
|
||||
def test_create_pickle_entry(app_context: AppContext, admin: User) -> None: # noqa: F811
|
||||
from superset.commands.key_value.create import CreateKeyValueCommand
|
||||
from superset.key_value.models import KeyValueEntry
|
||||
|
||||
with override_user(admin):
|
||||
key = CreateKeyValueCommand(
|
||||
resource=RESOURCE,
|
||||
value=PICKLE_VALUE,
|
||||
codec=PICKLE_CODEC,
|
||||
).run()
|
||||
entry = db.session.query(KeyValueEntry).filter_by(id=key.id).one()
|
||||
assert type(pickle.loads(entry.value)) == type(PICKLE_VALUE)
|
||||
assert entry.created_by_fk == admin.id
|
||||
db.session.delete(entry)
|
||||
db.session.commit()
|
||||
|
|
@ -1,84 +0,0 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from flask.ctx import AppContext
|
||||
from flask_appbuilder.security.sqla.models import User
|
||||
|
||||
from superset.extensions import db
|
||||
from superset.utils import json
|
||||
from tests.integration_tests.key_value.commands.fixtures import (
|
||||
admin, # noqa: F401
|
||||
JSON_VALUE,
|
||||
RESOURCE,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.key_value.models import KeyValueEntry
|
||||
|
||||
ID_KEY = 234
|
||||
UUID_KEY = UUID("5aae143c-44f1-478e-9153-ae6154df333a")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def key_value_entry() -> KeyValueEntry:
|
||||
from superset.key_value.models import KeyValueEntry
|
||||
|
||||
entry = KeyValueEntry(
|
||||
id=ID_KEY,
|
||||
uuid=UUID_KEY,
|
||||
resource=RESOURCE,
|
||||
value=bytes(json.dumps(JSON_VALUE), encoding="utf-8"),
|
||||
)
|
||||
db.session.add(entry)
|
||||
db.session.flush()
|
||||
return entry
|
||||
|
||||
|
||||
def test_delete_id_entry(
|
||||
app_context: AppContext,
|
||||
admin: User, # noqa: F811
|
||||
key_value_entry: KeyValueEntry,
|
||||
) -> None:
|
||||
from superset.commands.key_value.delete import DeleteKeyValueCommand
|
||||
|
||||
assert DeleteKeyValueCommand(resource=RESOURCE, key=ID_KEY).run() is True
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_delete_uuid_entry(
|
||||
app_context: AppContext,
|
||||
admin: User, # noqa: F811
|
||||
key_value_entry: KeyValueEntry,
|
||||
) -> None:
|
||||
from superset.commands.key_value.delete import DeleteKeyValueCommand
|
||||
|
||||
assert DeleteKeyValueCommand(resource=RESOURCE, key=UUID_KEY).run() is True
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_delete_entry_missing(
|
||||
app_context: AppContext,
|
||||
admin: User, # noqa: F811
|
||||
) -> None:
|
||||
from superset.commands.key_value.delete import DeleteKeyValueCommand
|
||||
|
||||
assert DeleteKeyValueCommand(resource=RESOURCE, key=456).run() is False
|
||||
|
|
@ -1,69 +0,0 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from flask_appbuilder.security.sqla.models import User
|
||||
|
||||
from superset.extensions import db
|
||||
from superset.key_value.types import (
|
||||
JsonKeyValueCodec,
|
||||
KeyValueResource,
|
||||
PickleKeyValueCodec,
|
||||
)
|
||||
from superset.utils import json
|
||||
from tests.integration_tests.test_app import app
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.key_value.models import KeyValueEntry
|
||||
|
||||
ID_KEY = 123
|
||||
UUID_KEY = UUID("3e7a2ab8-bcaf-49b0-a5df-dfb432f291cc")
|
||||
RESOURCE = KeyValueResource.APP
|
||||
JSON_VALUE = {"foo": "bar"}
|
||||
PICKLE_VALUE = object()
|
||||
JSON_CODEC = JsonKeyValueCodec()
|
||||
PICKLE_CODEC = PickleKeyValueCodec()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def key_value_entry() -> Generator[KeyValueEntry, None, None]:
|
||||
from superset.key_value.models import KeyValueEntry
|
||||
|
||||
entry = KeyValueEntry(
|
||||
id=ID_KEY,
|
||||
uuid=UUID_KEY,
|
||||
resource=RESOURCE,
|
||||
value=bytes(json.dumps(JSON_VALUE), encoding="utf-8"),
|
||||
)
|
||||
db.session.add(entry)
|
||||
db.session.flush()
|
||||
yield entry
|
||||
db.session.delete(entry)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def admin() -> User:
|
||||
with app.app_context(): # noqa: F841
|
||||
admin = db.session.query(User).filter_by(username="admin").one()
|
||||
return admin
|
||||
|
|
@ -1,103 +0,0 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from flask.ctx import AppContext
|
||||
|
||||
from superset.extensions import db
|
||||
from superset.utils import json
|
||||
from tests.integration_tests.key_value.commands.fixtures import (
|
||||
ID_KEY,
|
||||
JSON_CODEC,
|
||||
JSON_VALUE,
|
||||
key_value_entry, # noqa: F401
|
||||
RESOURCE,
|
||||
UUID_KEY,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.key_value.models import KeyValueEntry
|
||||
|
||||
|
||||
def test_get_id_entry(app_context: AppContext, key_value_entry: KeyValueEntry) -> None: # noqa: F811
|
||||
from superset.commands.key_value.get import GetKeyValueCommand
|
||||
|
||||
value = GetKeyValueCommand(resource=RESOURCE, key=ID_KEY, codec=JSON_CODEC).run()
|
||||
assert value == JSON_VALUE
|
||||
|
||||
|
||||
def test_get_uuid_entry(
|
||||
app_context: AppContext,
|
||||
key_value_entry: KeyValueEntry, # noqa: F811
|
||||
) -> None:
|
||||
from superset.commands.key_value.get import GetKeyValueCommand
|
||||
|
||||
value = GetKeyValueCommand(resource=RESOURCE, key=UUID_KEY, codec=JSON_CODEC).run()
|
||||
assert value == JSON_VALUE
|
||||
|
||||
|
||||
def test_get_id_entry_missing(
|
||||
app_context: AppContext,
|
||||
key_value_entry: KeyValueEntry, # noqa: F811
|
||||
) -> None:
|
||||
from superset.commands.key_value.get import GetKeyValueCommand
|
||||
|
||||
value = GetKeyValueCommand(resource=RESOURCE, key=456, codec=JSON_CODEC).run()
|
||||
assert value is None
|
||||
|
||||
|
||||
def test_get_expired_entry(app_context: AppContext) -> None:
|
||||
from superset.commands.key_value.get import GetKeyValueCommand
|
||||
from superset.key_value.models import KeyValueEntry
|
||||
|
||||
entry = KeyValueEntry(
|
||||
id=678,
|
||||
uuid=uuid.uuid4(),
|
||||
resource=RESOURCE,
|
||||
value=bytes(json.dumps(JSON_VALUE), encoding="utf-8"),
|
||||
expires_on=datetime.now() - timedelta(days=1),
|
||||
)
|
||||
db.session.add(entry)
|
||||
db.session.flush()
|
||||
value = GetKeyValueCommand(resource=RESOURCE, key=ID_KEY, codec=JSON_CODEC).run()
|
||||
assert value is None
|
||||
db.session.delete(entry)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_get_future_expiring_entry(app_context: AppContext) -> None:
|
||||
from superset.commands.key_value.get import GetKeyValueCommand
|
||||
from superset.key_value.models import KeyValueEntry
|
||||
|
||||
id_ = 789
|
||||
entry = KeyValueEntry(
|
||||
id=id_,
|
||||
uuid=uuid.uuid4(),
|
||||
resource=RESOURCE,
|
||||
value=bytes(json.dumps(JSON_VALUE), encoding="utf-8"),
|
||||
expires_on=datetime.now() + timedelta(days=1),
|
||||
)
|
||||
db.session.add(entry)
|
||||
db.session.flush()
|
||||
value = GetKeyValueCommand(resource=RESOURCE, key=id_, codec=JSON_CODEC).run()
|
||||
assert value == JSON_VALUE
|
||||
db.session.delete(entry)
|
||||
db.session.commit()
|
||||
|
|
@ -1,97 +0,0 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from flask.ctx import AppContext
|
||||
from flask_appbuilder.security.sqla.models import User
|
||||
|
||||
from superset.extensions import db
|
||||
from superset.utils import json
|
||||
from superset.utils.core import override_user
|
||||
from tests.integration_tests.key_value.commands.fixtures import (
|
||||
admin, # noqa: F401
|
||||
ID_KEY,
|
||||
JSON_CODEC,
|
||||
key_value_entry, # noqa: F401
|
||||
RESOURCE,
|
||||
UUID_KEY,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.key_value.models import KeyValueEntry
|
||||
|
||||
|
||||
NEW_VALUE = "new value"
|
||||
|
||||
|
||||
def test_update_id_entry(
|
||||
app_context: AppContext,
|
||||
admin: User, # noqa: F811
|
||||
key_value_entry: KeyValueEntry, # noqa: F811
|
||||
) -> None:
|
||||
from superset.commands.key_value.update import UpdateKeyValueCommand
|
||||
from superset.key_value.models import KeyValueEntry
|
||||
|
||||
with override_user(admin):
|
||||
key = UpdateKeyValueCommand(
|
||||
resource=RESOURCE,
|
||||
key=ID_KEY,
|
||||
value=NEW_VALUE,
|
||||
codec=JSON_CODEC,
|
||||
).run()
|
||||
assert key is not None
|
||||
assert key.id == ID_KEY
|
||||
entry = db.session.query(KeyValueEntry).filter_by(id=ID_KEY).one()
|
||||
assert json.loads(entry.value) == NEW_VALUE
|
||||
assert entry.changed_by_fk == admin.id
|
||||
|
||||
|
||||
def test_update_uuid_entry(
|
||||
app_context: AppContext,
|
||||
admin: User, # noqa: F811
|
||||
key_value_entry: KeyValueEntry, # noqa: F811
|
||||
) -> None:
|
||||
from superset.commands.key_value.update import UpdateKeyValueCommand
|
||||
from superset.key_value.models import KeyValueEntry
|
||||
|
||||
with override_user(admin):
|
||||
key = UpdateKeyValueCommand(
|
||||
resource=RESOURCE,
|
||||
key=UUID_KEY,
|
||||
value=NEW_VALUE,
|
||||
codec=JSON_CODEC,
|
||||
).run()
|
||||
assert key is not None
|
||||
assert key.uuid == UUID_KEY
|
||||
entry = db.session.query(KeyValueEntry).filter_by(uuid=UUID_KEY).one()
|
||||
assert json.loads(entry.value) == NEW_VALUE
|
||||
assert entry.changed_by_fk == admin.id
|
||||
|
||||
|
||||
def test_update_missing_entry(app_context: AppContext, admin: User) -> None: # noqa: F811
|
||||
from superset.commands.key_value.update import UpdateKeyValueCommand
|
||||
|
||||
with override_user(admin):
|
||||
key = UpdateKeyValueCommand(
|
||||
resource=RESOURCE,
|
||||
key=456,
|
||||
value=NEW_VALUE,
|
||||
codec=JSON_CODEC,
|
||||
).run()
|
||||
assert key is None
|
||||
|
|
@ -1,101 +0,0 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from flask.ctx import AppContext
|
||||
from flask_appbuilder.security.sqla.models import User
|
||||
|
||||
from superset.extensions import db
|
||||
from superset.utils import json
|
||||
from superset.utils.core import override_user
|
||||
from tests.integration_tests.key_value.commands.fixtures import (
|
||||
admin, # noqa: F401
|
||||
ID_KEY,
|
||||
JSON_CODEC,
|
||||
key_value_entry, # noqa: F401
|
||||
RESOURCE,
|
||||
UUID_KEY,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.key_value.models import KeyValueEntry
|
||||
|
||||
|
||||
NEW_VALUE = "new value"
|
||||
|
||||
|
||||
def test_upsert_id_entry(
|
||||
app_context: AppContext,
|
||||
admin: User, # noqa: F811
|
||||
key_value_entry: KeyValueEntry, # noqa: F811
|
||||
) -> None:
|
||||
from superset.commands.key_value.upsert import UpsertKeyValueCommand
|
||||
from superset.key_value.models import KeyValueEntry
|
||||
|
||||
with override_user(admin):
|
||||
key = UpsertKeyValueCommand(
|
||||
resource=RESOURCE,
|
||||
key=ID_KEY,
|
||||
value=NEW_VALUE,
|
||||
codec=JSON_CODEC,
|
||||
).run()
|
||||
assert key is not None
|
||||
assert key.id == ID_KEY
|
||||
entry = db.session.query(KeyValueEntry).filter_by(id=int(ID_KEY)).one()
|
||||
assert json.loads(entry.value) == NEW_VALUE
|
||||
assert entry.changed_by_fk == admin.id
|
||||
|
||||
|
||||
def test_upsert_uuid_entry(
|
||||
app_context: AppContext,
|
||||
admin: User, # noqa: F811
|
||||
key_value_entry: KeyValueEntry, # noqa: F811
|
||||
) -> None:
|
||||
from superset.commands.key_value.upsert import UpsertKeyValueCommand
|
||||
from superset.key_value.models import KeyValueEntry
|
||||
|
||||
with override_user(admin):
|
||||
key = UpsertKeyValueCommand(
|
||||
resource=RESOURCE,
|
||||
key=UUID_KEY,
|
||||
value=NEW_VALUE,
|
||||
codec=JSON_CODEC,
|
||||
).run()
|
||||
assert key is not None
|
||||
assert key.uuid == UUID_KEY
|
||||
entry = db.session.query(KeyValueEntry).filter_by(uuid=UUID_KEY).one()
|
||||
assert json.loads(entry.value) == NEW_VALUE
|
||||
assert entry.changed_by_fk == admin.id
|
||||
|
||||
|
||||
def test_upsert_missing_entry(app_context: AppContext, admin: User) -> None: # noqa: F811
|
||||
from superset.commands.key_value.upsert import UpsertKeyValueCommand
|
||||
from superset.key_value.models import KeyValueEntry
|
||||
|
||||
with override_user(admin):
|
||||
key = UpsertKeyValueCommand(
|
||||
resource=RESOURCE,
|
||||
key=456,
|
||||
value=NEW_VALUE,
|
||||
codec=JSON_CODEC,
|
||||
).run()
|
||||
assert key is not None
|
||||
assert key.id == 456
|
||||
db.session.query(KeyValueEntry).filter_by(id=456).delete()
|
||||
db.session.commit()
|
||||
|
|
@ -0,0 +1,395 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
# pylint: disable=unused-argument, import-outside-toplevel, unused-import
|
||||
from __future__ import annotations
|
||||
|
||||
import pickle
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Generator, TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from flask.ctx import AppContext
|
||||
from flask_appbuilder.security.sqla.models import User
|
||||
|
||||
from superset.extensions import db
|
||||
from superset.key_value.exceptions import (
|
||||
KeyValueCreateFailedError,
|
||||
KeyValueUpdateFailedError,
|
||||
)
|
||||
from superset.key_value.types import (
|
||||
JsonKeyValueCodec,
|
||||
KeyValueResource,
|
||||
PickleKeyValueCodec,
|
||||
)
|
||||
from superset.utils import json
|
||||
from superset.utils.core import override_user
|
||||
from tests.unit_tests.fixtures.common import admin_user, after_each # noqa: F401
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.key_value.models import KeyValueEntry
|
||||
|
||||
ID_KEY = 123
|
||||
UUID_KEY = UUID("3e7a2ab8-bcaf-49b0-a5df-dfb432f291cc")
|
||||
RESOURCE = KeyValueResource.APP
|
||||
JSON_VALUE = {"foo": "bar"}
|
||||
PICKLE_VALUE = object()
|
||||
JSON_CODEC = JsonKeyValueCodec()
|
||||
PICKLE_CODEC = PickleKeyValueCodec()
|
||||
NEW_VALUE = {"foo": "baz"}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def key_value_entry() -> Generator[KeyValueEntry, None, None]:
|
||||
from superset.key_value.models import KeyValueEntry
|
||||
|
||||
entry = KeyValueEntry(
|
||||
id=ID_KEY,
|
||||
uuid=UUID_KEY,
|
||||
resource=RESOURCE,
|
||||
value=JSON_CODEC.encode(JSON_VALUE),
|
||||
)
|
||||
db.session.add(entry)
|
||||
db.session.flush()
|
||||
yield entry
|
||||
|
||||
|
||||
def test_create_id_entry(
|
||||
app_context: AppContext,
|
||||
admin_user: User, # noqa: F811
|
||||
after_each: None, # noqa: F811
|
||||
) -> None:
|
||||
from superset.daos.key_value import KeyValueDAO
|
||||
from superset.key_value.models import KeyValueEntry
|
||||
|
||||
with override_user(admin_user):
|
||||
created_entry = KeyValueDAO.create_entry(
|
||||
resource=RESOURCE,
|
||||
value=JSON_VALUE,
|
||||
codec=JSON_CODEC,
|
||||
)
|
||||
db.session.flush()
|
||||
found_entry = (
|
||||
db.session.query(KeyValueEntry).filter_by(id=created_entry.id).one()
|
||||
)
|
||||
assert json.loads(found_entry.value) == JSON_VALUE
|
||||
assert found_entry.created_by_fk == admin_user.id
|
||||
|
||||
|
||||
def test_create_uuid_entry(
|
||||
app_context: AppContext,
|
||||
admin_user: User, # noqa: F811
|
||||
after_each: None, # noqa: F811
|
||||
) -> None:
|
||||
from superset.daos.key_value import KeyValueDAO
|
||||
from superset.key_value.models import KeyValueEntry
|
||||
|
||||
with override_user(admin_user):
|
||||
created_entry = KeyValueDAO.create_entry(
|
||||
resource=RESOURCE, value=JSON_VALUE, codec=JSON_CODEC
|
||||
)
|
||||
db.session.flush()
|
||||
|
||||
found_entry = (
|
||||
db.session.query(KeyValueEntry).filter_by(uuid=created_entry.uuid).one()
|
||||
)
|
||||
assert json.loads(found_entry.value) == JSON_VALUE
|
||||
assert found_entry.created_by_fk == admin_user.id
|
||||
|
||||
|
||||
def test_create_fail_json_entry(
|
||||
app_context: AppContext,
|
||||
after_each: None, # noqa: F811
|
||||
) -> None:
|
||||
from superset.daos.key_value import KeyValueDAO
|
||||
|
||||
with pytest.raises(KeyValueCreateFailedError):
|
||||
KeyValueDAO.create_entry(
|
||||
resource=RESOURCE,
|
||||
value=PICKLE_VALUE,
|
||||
codec=JSON_CODEC,
|
||||
)
|
||||
|
||||
|
||||
def test_create_pickle_entry(
|
||||
app_context: AppContext,
|
||||
admin_user: User, # noqa: F811
|
||||
after_each: None, # noqa: F811
|
||||
) -> None:
|
||||
from superset.daos.key_value import KeyValueDAO
|
||||
from superset.key_value.models import KeyValueEntry
|
||||
|
||||
with override_user(admin_user):
|
||||
created_entry = KeyValueDAO.create_entry(
|
||||
resource=RESOURCE,
|
||||
value=PICKLE_VALUE,
|
||||
codec=PICKLE_CODEC,
|
||||
)
|
||||
db.session.flush()
|
||||
found_entry = (
|
||||
db.session.query(KeyValueEntry).filter_by(id=created_entry.id).one()
|
||||
)
|
||||
assert type(pickle.loads(found_entry.value)) == type(PICKLE_VALUE)
|
||||
assert found_entry.created_by_fk == admin_user.id
|
||||
|
||||
|
||||
def test_get_value(
|
||||
app_context: AppContext,
|
||||
key_value_entry: KeyValueEntry,
|
||||
after_each: None, # noqa: F811
|
||||
) -> None:
|
||||
from superset.daos.key_value import KeyValueDAO
|
||||
|
||||
value = KeyValueDAO.get_value(
|
||||
resource=RESOURCE,
|
||||
key=key_value_entry.id,
|
||||
codec=JSON_CODEC,
|
||||
)
|
||||
assert value == JSON_VALUE
|
||||
|
||||
|
||||
def test_get_id_entry(
|
||||
app_context: AppContext,
|
||||
key_value_entry: KeyValueEntry,
|
||||
after_each: None, # noqa: F811
|
||||
) -> None:
|
||||
from superset.daos.key_value import KeyValueDAO
|
||||
|
||||
found_entry = KeyValueDAO.get_entry(resource=RESOURCE, key=key_value_entry.id)
|
||||
assert found_entry is not None
|
||||
assert found_entry.id == key_value_entry.id
|
||||
|
||||
|
||||
def test_get_uuid_entry(
|
||||
app_context: AppContext,
|
||||
key_value_entry: KeyValueEntry, # noqa: F811
|
||||
after_each: None, # noqa: F811
|
||||
) -> None:
|
||||
from superset.daos.key_value import KeyValueDAO
|
||||
|
||||
found_entry = KeyValueDAO.get_entry(resource=RESOURCE, key=key_value_entry.uuid)
|
||||
assert found_entry is not None
|
||||
assert JSON_CODEC.decode(found_entry.value) == JSON_VALUE
|
||||
|
||||
|
||||
def test_get_id_entry_missing(
|
||||
app_context: AppContext,
|
||||
after_each: None, # noqa: F811
|
||||
) -> None:
|
||||
from superset.daos.key_value import KeyValueDAO
|
||||
|
||||
entry = KeyValueDAO.get_entry(resource=RESOURCE, key=456)
|
||||
assert entry is None
|
||||
|
||||
|
||||
def test_get_expired_entry(
|
||||
app_context: AppContext,
|
||||
after_each: None, # noqa: F811
|
||||
) -> None:
|
||||
from superset.daos.key_value import KeyValueDAO
|
||||
|
||||
created_entry = KeyValueDAO.create_entry(
|
||||
resource=RESOURCE,
|
||||
value=JSON_VALUE,
|
||||
codec=JSON_CODEC,
|
||||
key=ID_KEY,
|
||||
expires_on=datetime.now() - timedelta(days=1),
|
||||
)
|
||||
found_entry = KeyValueDAO.get_entry(resource=RESOURCE, key=created_entry.id)
|
||||
assert found_entry is not None
|
||||
assert found_entry.is_expired() is True
|
||||
|
||||
|
||||
def test_get_future_expiring_entry(
|
||||
app_context: AppContext,
|
||||
after_each: None, # noqa: F811
|
||||
) -> None:
|
||||
from superset.daos.key_value import KeyValueDAO
|
||||
|
||||
created_entry = KeyValueDAO.create_entry(
|
||||
resource=RESOURCE,
|
||||
value=JSON_VALUE,
|
||||
codec=JSON_CODEC,
|
||||
key=ID_KEY,
|
||||
expires_on=datetime.now() + timedelta(days=1),
|
||||
)
|
||||
found_entry = KeyValueDAO.get_entry(resource=RESOURCE, key=created_entry.id)
|
||||
assert found_entry is not None
|
||||
assert found_entry.is_expired() is False
|
||||
|
||||
|
||||
def test_update_id_entry(
|
||||
app_context: AppContext,
|
||||
key_value_entry: KeyValueEntry, # noqa: F811
|
||||
admin_user: User, # noqa: F811
|
||||
after_each: None, # noqa: F811
|
||||
) -> None:
|
||||
from superset.daos.key_value import KeyValueDAO
|
||||
|
||||
with override_user(admin_user):
|
||||
updated_entry = KeyValueDAO.update_entry(
|
||||
resource=RESOURCE,
|
||||
key=ID_KEY,
|
||||
value=NEW_VALUE,
|
||||
codec=JSON_CODEC,
|
||||
)
|
||||
db.session.flush()
|
||||
assert updated_entry is not None
|
||||
assert JSON_CODEC.decode(updated_entry.value) == NEW_VALUE
|
||||
assert updated_entry.id == ID_KEY
|
||||
assert updated_entry.uuid == UUID_KEY
|
||||
found_entry = KeyValueDAO.get_entry(resource=RESOURCE, key=ID_KEY)
|
||||
assert found_entry is not None
|
||||
assert JSON_CODEC.decode(found_entry.value) == NEW_VALUE
|
||||
assert found_entry.changed_by_fk == admin_user.id
|
||||
|
||||
|
||||
def test_update_uuid_entry(
|
||||
app_context: AppContext,
|
||||
key_value_entry: KeyValueEntry, # noqa: F811
|
||||
admin_user: User, # noqa: F811
|
||||
after_each: None, # noqa: F811
|
||||
) -> None:
|
||||
from superset.daos.key_value import KeyValueDAO
|
||||
|
||||
with override_user(admin_user):
|
||||
updated_entry = KeyValueDAO.update_entry(
|
||||
resource=RESOURCE,
|
||||
key=UUID_KEY,
|
||||
value=NEW_VALUE,
|
||||
codec=JSON_CODEC,
|
||||
)
|
||||
db.session.flush()
|
||||
assert updated_entry is not None
|
||||
assert JSON_CODEC.decode(updated_entry.value) == NEW_VALUE
|
||||
assert updated_entry.id == ID_KEY
|
||||
assert updated_entry.uuid == UUID_KEY
|
||||
found_entry = KeyValueDAO.get_entry(resource=RESOURCE, key=UUID_KEY)
|
||||
assert found_entry is not None
|
||||
assert JSON_CODEC.decode(found_entry.value) == NEW_VALUE
|
||||
assert found_entry.changed_by_fk == admin_user.id
|
||||
|
||||
|
||||
def test_update_missing_entry(
|
||||
app_context: AppContext,
|
||||
admin_user: User, # noqa: F811
|
||||
after_each: None, # noqa: F811
|
||||
) -> None:
|
||||
from superset.daos.key_value import KeyValueDAO
|
||||
|
||||
with override_user(admin_user):
|
||||
with pytest.raises(KeyValueUpdateFailedError):
|
||||
KeyValueDAO.update_entry(
|
||||
resource=RESOURCE,
|
||||
key=456,
|
||||
value=NEW_VALUE,
|
||||
codec=JSON_CODEC,
|
||||
)
|
||||
|
||||
|
||||
def test_upsert_id_entry(
|
||||
app_context: AppContext,
|
||||
key_value_entry: KeyValueEntry, # noqa: F811
|
||||
admin_user: User, # noqa: F811
|
||||
after_each: None, # noqa: F811
|
||||
) -> None:
|
||||
from superset.daos.key_value import KeyValueDAO
|
||||
|
||||
with override_user(admin_user):
|
||||
entry = KeyValueDAO.upsert_entry(
|
||||
resource=RESOURCE,
|
||||
key=ID_KEY,
|
||||
value=NEW_VALUE,
|
||||
codec=JSON_CODEC,
|
||||
)
|
||||
found_entry = KeyValueDAO.get_entry(resource=RESOURCE, key=ID_KEY)
|
||||
assert found_entry is not None
|
||||
assert JSON_CODEC.decode(found_entry.value) == NEW_VALUE
|
||||
assert entry.changed_by_fk == admin_user.id
|
||||
|
||||
|
||||
def test_upsert_uuid_entry(
|
||||
app_context: AppContext,
|
||||
key_value_entry: KeyValueEntry, # noqa: F811
|
||||
admin_user: User, # noqa: F811
|
||||
after_each: None, # noqa: F811
|
||||
) -> None:
|
||||
from superset.daos.key_value import KeyValueDAO
|
||||
|
||||
with override_user(admin_user):
|
||||
entry = KeyValueDAO.upsert_entry(
|
||||
resource=RESOURCE,
|
||||
key=UUID_KEY,
|
||||
value=NEW_VALUE,
|
||||
codec=JSON_CODEC,
|
||||
)
|
||||
db.session.flush()
|
||||
assert entry is not None
|
||||
assert entry.id == ID_KEY
|
||||
assert entry.uuid == UUID_KEY
|
||||
found_entry = KeyValueDAO.get_entry(resource=RESOURCE, key=UUID_KEY)
|
||||
assert found_entry is not None
|
||||
assert JSON_CODEC.decode(found_entry.value) == NEW_VALUE
|
||||
assert entry.changed_by_fk == admin_user.id
|
||||
|
||||
|
||||
def test_upsert_missing_entry(
|
||||
app_context: AppContext,
|
||||
after_each: None, # noqa: F811
|
||||
) -> None:
|
||||
from superset.daos.key_value import KeyValueDAO
|
||||
|
||||
entry = KeyValueDAO.get_entry(resource=RESOURCE, key=ID_KEY)
|
||||
assert entry is None
|
||||
KeyValueDAO.upsert_entry(
|
||||
resource=RESOURCE,
|
||||
key=ID_KEY,
|
||||
value=NEW_VALUE,
|
||||
codec=JSON_CODEC,
|
||||
)
|
||||
entry = KeyValueDAO.get_entry(resource=RESOURCE, key=ID_KEY)
|
||||
assert entry is not None
|
||||
assert JSON_CODEC.decode(entry.value) == NEW_VALUE
|
||||
|
||||
|
||||
def test_delete_id_entry(
|
||||
app_context: AppContext,
|
||||
key_value_entry: KeyValueEntry,
|
||||
after_each: None, # noqa: F811
|
||||
) -> None:
|
||||
from superset.daos.key_value import KeyValueDAO
|
||||
|
||||
assert KeyValueDAO.delete_entry(resource=RESOURCE, key=ID_KEY) is True
|
||||
|
||||
|
||||
def test_delete_uuid_entry(
|
||||
app_context: AppContext,
|
||||
key_value_entry: KeyValueEntry,
|
||||
after_each: None, # noqa: F811
|
||||
) -> None:
|
||||
from superset.daos.key_value import KeyValueDAO
|
||||
|
||||
assert KeyValueDAO.delete_entry(resource=RESOURCE, key=UUID_KEY) is True
|
||||
|
||||
|
||||
def test_delete_entry_missing(
|
||||
app_context: AppContext,
|
||||
after_each: None, # noqa: F811
|
||||
) -> None:
|
||||
from superset.daos.key_value import KeyValueDAO
|
||||
|
||||
assert KeyValueDAO.delete_entry(resource=RESOURCE, key=12345678) is False
|
||||
|
|
@ -22,17 +22,21 @@ from uuid import UUID
|
|||
|
||||
import pytest
|
||||
from freezegun import freeze_time
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from superset import db
|
||||
from superset.distributed_lock import KeyValueDistributedLock
|
||||
from superset.distributed_lock.types import LockValue
|
||||
from superset.distributed_lock.utils import get_key
|
||||
from superset.exceptions import CreateKeyValueDistributedLockFailedException
|
||||
from superset.key_value.types import JsonKeyValueCodec
|
||||
from superset.utils.lock import get_key, KeyValueDistributedLock
|
||||
|
||||
LOCK_VALUE: LockValue = {"value": True}
|
||||
MAIN_KEY = get_key("ns", a=1, b=2)
|
||||
OTHER_KEY = get_key("ns2", a=1, b=2)
|
||||
|
||||
|
||||
def _get_lock(key: UUID) -> Any:
|
||||
def _get_lock(key: UUID, session: Session) -> Any:
|
||||
from superset.key_value.models import KeyValueEntry
|
||||
|
||||
entry = db.session.query(KeyValueEntry).filter_by(uuid=key).first()
|
||||
|
|
@ -42,41 +46,56 @@ def _get_lock(key: UUID) -> Any:
|
|||
return JsonKeyValueCodec().decode(entry.value)
|
||||
|
||||
|
||||
def _get_other_session() -> Session:
|
||||
# This session is used to simulate what another worker will find in the metastore
|
||||
# during the locking process.
|
||||
from superset import db
|
||||
|
||||
bind = db.session.get_bind()
|
||||
SessionMaker = sessionmaker(bind=bind)
|
||||
return SessionMaker()
|
||||
|
||||
|
||||
def test_key_value_distributed_lock_happy_path() -> None:
|
||||
"""
|
||||
Test successfully acquiring and returning the distributed lock.
|
||||
|
||||
Note we use a nested transaction to ensure that the cleanup from the outer context
|
||||
manager is correctly invoked, otherwise a partial rollback would occur leaving the
|
||||
database in a fractured state.
|
||||
Note, we're using another session for asserting the lock state in the Metastore
|
||||
to simulate what another worker will observe. Otherwise, there's the risk that
|
||||
the assertions would only be using the non-committed state from the main session.
|
||||
"""
|
||||
session = _get_other_session()
|
||||
|
||||
with freeze_time("2021-01-01"):
|
||||
assert _get_lock(MAIN_KEY) is None
|
||||
assert _get_lock(MAIN_KEY, session) is None
|
||||
|
||||
with KeyValueDistributedLock("ns", a=1, b=2) as key:
|
||||
assert key == MAIN_KEY
|
||||
assert _get_lock(key) is True
|
||||
assert _get_lock(OTHER_KEY) is None
|
||||
assert _get_lock(key, session) == LOCK_VALUE
|
||||
assert _get_lock(OTHER_KEY, session) is None
|
||||
|
||||
with db.session.begin_nested():
|
||||
with pytest.raises(CreateKeyValueDistributedLockFailedException):
|
||||
with KeyValueDistributedLock("ns", a=1, b=2):
|
||||
pass
|
||||
with pytest.raises(CreateKeyValueDistributedLockFailedException):
|
||||
with KeyValueDistributedLock("ns", a=1, b=2):
|
||||
pass
|
||||
|
||||
assert _get_lock(MAIN_KEY) is None
|
||||
assert _get_lock(MAIN_KEY, session) is None
|
||||
|
||||
|
||||
def test_key_value_distributed_lock_expired() -> None:
|
||||
"""
|
||||
Test expiration of the distributed lock
|
||||
|
||||
Note, we're using another session for asserting the lock state in the Metastore
|
||||
to simulate what another worker will observe. Otherwise, there's the risk that
|
||||
the assertions would only be using the non-committed state from the main session.
|
||||
"""
|
||||
session = _get_other_session()
|
||||
|
||||
with freeze_time("2021-01-01"):
|
||||
assert _get_lock(MAIN_KEY) is None
|
||||
assert _get_lock(MAIN_KEY, session) is None
|
||||
with KeyValueDistributedLock("ns", a=1, b=2):
|
||||
assert _get_lock(MAIN_KEY) is True
|
||||
assert _get_lock(MAIN_KEY, session) == LOCK_VALUE
|
||||
with freeze_time("2022-01-01"):
|
||||
assert _get_lock(MAIN_KEY) is None
|
||||
assert _get_lock(MAIN_KEY, session) is None
|
||||
|
||||
assert _get_lock(MAIN_KEY) is None
|
||||
assert _get_lock(MAIN_KEY, session) is None
|
||||
|
|
@ -20,12 +20,15 @@ from __future__ import annotations
|
|||
import csv
|
||||
from datetime import datetime
|
||||
from io import BytesIO, StringIO
|
||||
from typing import Any
|
||||
from typing import Any, Generator
|
||||
|
||||
import pandas as pd
|
||||
import pytest
|
||||
from flask_appbuilder.security.sqla.models import Role, User
|
||||
from werkzeug.datastructures import FileStorage
|
||||
|
||||
from superset import db
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dttm() -> datetime:
|
||||
|
|
@ -73,3 +76,24 @@ def create_columnar_file(
|
|||
df.to_parquet(buffer, index=False)
|
||||
buffer.seek(0)
|
||||
return FileStorage(stream=buffer, filename=filename)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def admin_user() -> Generator[User, None, None]:
|
||||
role = db.session.query(Role).filter_by(name="Admin").one()
|
||||
user = User(
|
||||
first_name="Alice",
|
||||
last_name="Admin",
|
||||
email="alice_admin@example.org",
|
||||
username="alice_admin",
|
||||
roles=[role],
|
||||
)
|
||||
db.session.add(user)
|
||||
db.session.flush()
|
||||
yield user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def after_each() -> Generator[None, None, None]:
|
||||
yield
|
||||
db.session.rollback()
|
||||
|
|
|
|||
Loading…
Reference in New Issue