chore(dao/command): Add transaction decorator to try to enforce "unit of work" (#24969)

This commit is contained in:
John Bodley 2024-06-28 12:33:56 -07:00 committed by GitHub
parent a3f0d00714
commit 8fb8199a55
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
151 changed files with 681 additions and 916 deletions

View File

@ -240,6 +240,7 @@ ignore_basepython_conflict = true
commands =
superset db upgrade
superset init
superset load-test-users
# use -s to be able to use break pointers.
# no args or tests/* can be passed as an argument to run all tests
pytest -s {posargs}

View File

@ -17,8 +17,10 @@
from collections import defaultdict
from superset import security_manager
from superset.utils.decorators import transaction
@transaction()
def cleanup_permissions() -> None:
# 1. Clean up duplicates.
pvms = security_manager.get_session.query(
@ -29,7 +31,6 @@ def cleanup_permissions() -> None:
for pvm in pvms:
pvms_dict[(pvm.permission, pvm.view_menu)].append(pvm)
duplicates = [v for v in pvms_dict.values() if len(v) > 1]
len(duplicates)
for pvm_list in duplicates:
first_prm = pvm_list[0]
@ -38,7 +39,6 @@ def cleanup_permissions() -> None:
roles = roles.union(pvm.role)
security_manager.get_session.delete(pvm)
first_prm.roles = list(roles)
security_manager.get_session.commit()
pvms = security_manager.get_session.query(
security_manager.permissionview_model
@ -52,7 +52,6 @@ def cleanup_permissions() -> None:
for pvm in pvms:
if not (pvm.view_menu and pvm.permission):
security_manager.get_session.delete(pvm)
security_manager.get_session.commit()
pvms = security_manager.get_session.query(
security_manager.permissionview_model
@ -63,7 +62,6 @@ def cleanup_permissions() -> None:
roles = security_manager.get_session.query(security_manager.role_model).all()
for role in roles:
role.permissions = [p for p in role.permissions if p]
security_manager.get_session.commit()
# 4. Delete empty roles from permission view menus
pvms = security_manager.get_session.query(
@ -71,7 +69,6 @@ def cleanup_permissions() -> None:
).all()
for pvm in pvms:
pvm.role = [r for r in pvm.role if r]
security_manager.get_session.commit()
cleanup_permissions()

View File

@ -29,6 +29,7 @@ echo "Superset config module: $SUPERSET_CONFIG"
superset db upgrade
superset init
superset load-test-users
echo "Running tests"

View File

@ -113,8 +113,10 @@ class CacheRestApi(BaseSupersetModelRestApi):
delete_stmt = CacheKey.__table__.delete().where( # pylint: disable=no-member
CacheKey.cache_key.in_(cache_keys)
)
db.session.execute(delete_stmt)
db.session.commit()
with db.session.begin_nested():
db.session.execute(delete_stmt)
stats_logger_manager.instance.gauge(
"invalidated_cache", len(cache_keys)
)
@ -125,7 +127,5 @@ class CacheRestApi(BaseSupersetModelRestApi):
)
except SQLAlchemyError as ex: # pragma: no cover
logger.error(ex, exc_info=True)
db.session.rollback()
return self.response_500(str(ex))
db.session.commit()
return self.response(201)

View File

@ -20,6 +20,7 @@ import click
from flask.cli import with_appcontext
import superset.utils.database as database_utils
from superset.utils.decorators import transaction
logger = logging.getLogger(__name__)
@ -89,6 +90,7 @@ def load_examples_run(
@click.command()
@with_appcontext
@transaction()
@click.option("--load-test-data", "-t", is_flag=True, help="Load additional test data")
@click.option("--load-big-data", "-b", is_flag=True, help="Load additional big data")
@click.option(

View File

@ -27,6 +27,7 @@ from flask.cli import FlaskGroup, with_appcontext
from superset import app, appbuilder, cli, security_manager
from superset.cli.lib import normalize_token
from superset.extensions import db
from superset.utils.decorators import transaction
logger = logging.getLogger(__name__)
@ -60,6 +61,7 @@ for load, module_name, is_pkg in pkgutil.walk_packages(
@superset.command()
@with_appcontext
@transaction()
def init() -> None:
"""Inits the Superset application"""
appbuilder.add_permissions(update_perms=True)

View File

@ -22,12 +22,14 @@ from flask.cli import with_appcontext
import superset.utils.database as database_utils
from superset import app, security_manager
from superset.utils.decorators import transaction
logger = logging.getLogger(__name__)
@click.command()
@with_appcontext
@transaction()
def load_test_users() -> None:
"""
Loads admin, alpha, and gamma user for testing purposes
@ -35,15 +37,7 @@ def load_test_users() -> None:
Syncs permissions for those users/roles
"""
print(Fore.GREEN + "Loading a set of users for unit tests")
load_test_users_run()
def load_test_users_run() -> None:
"""
Loads admin, alpha, and gamma user for testing purposes
Syncs permissions for those users/roles
"""
if app.config["TESTING"]:
sm = security_manager
@ -84,4 +78,3 @@ def load_test_users_run() -> None:
sm.find_role(role),
password="general",
)
sm.get_session.commit()

View File

@ -30,6 +30,7 @@ from flask_appbuilder.api import BaseApi
from flask_appbuilder.api.manager import resolver
import superset.utils.database as database_utils
from superset.utils.decorators import transaction
from superset.utils.encrypt import SecretsMigrator
logger = logging.getLogger(__name__)
@ -37,6 +38,7 @@ logger = logging.getLogger(__name__)
@click.command()
@with_appcontext
@transaction()
@click.option("--database_name", "-d", help="Database name to change")
@click.option("--uri", "-u", help="Database URI to change")
@click.option(
@ -53,6 +55,7 @@ def set_database_uri(database_name: str, uri: str, skip_create: bool) -> None:
@click.command()
@with_appcontext
@transaction()
def sync_tags() -> None:
"""Rebuilds special tags (owner, type, favorited by)."""
# pylint: disable=no-member

View File

@ -16,6 +16,7 @@
# under the License.
import logging
from datetime import datetime
from functools import partial
from typing import Any, Optional
from flask_appbuilder.models.sqla import Model
@ -30,7 +31,7 @@ from superset.commands.annotation_layer.annotation.exceptions import (
from superset.commands.annotation_layer.exceptions import AnnotationLayerNotFoundError
from superset.commands.base import BaseCommand
from superset.daos.annotation_layer import AnnotationDAO, AnnotationLayerDAO
from superset.daos.exceptions import DAOCreateFailedError
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__)
@ -39,13 +40,10 @@ class CreateAnnotationCommand(BaseCommand):
def __init__(self, data: dict[str, Any]):
self._properties = data.copy()
@transaction(on_error=partial(on_error, reraise=AnnotationCreateFailedError))
def run(self) -> Model:
self.validate()
try:
return AnnotationDAO.create(attributes=self._properties)
except DAOCreateFailedError as ex:
logger.exception(ex.exception)
raise AnnotationCreateFailedError() from ex
return AnnotationDAO.create(attributes=self._properties)
def validate(self) -> None:
exceptions: list[ValidationError] = []

View File

@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from functools import partial
from typing import Optional
from superset.commands.annotation_layer.annotation.exceptions import (
@ -23,8 +24,8 @@ from superset.commands.annotation_layer.annotation.exceptions import (
)
from superset.commands.base import BaseCommand
from superset.daos.annotation_layer import AnnotationDAO
from superset.daos.exceptions import DAODeleteFailedError
from superset.models.annotations import Annotation
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__)
@ -34,15 +35,11 @@ class DeleteAnnotationCommand(BaseCommand):
self._model_ids = model_ids
self._models: Optional[list[Annotation]] = None
@transaction(on_error=partial(on_error, reraise=AnnotationDeleteFailedError))
def run(self) -> None:
self.validate()
assert self._models
try:
AnnotationDAO.delete(self._models)
except DAODeleteFailedError as ex:
logger.exception(ex.exception)
raise AnnotationDeleteFailedError() from ex
AnnotationDAO.delete(self._models)
def validate(self) -> None:
# Validate/populate model exists

View File

@ -16,6 +16,7 @@
# under the License.
import logging
from datetime import datetime
from functools import partial
from typing import Any, Optional
from flask_appbuilder.models.sqla import Model
@ -31,8 +32,8 @@ from superset.commands.annotation_layer.annotation.exceptions import (
from superset.commands.annotation_layer.exceptions import AnnotationLayerNotFoundError
from superset.commands.base import BaseCommand
from superset.daos.annotation_layer import AnnotationDAO, AnnotationLayerDAO
from superset.daos.exceptions import DAOUpdateFailedError
from superset.models.annotations import Annotation
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__)
@ -43,16 +44,11 @@ class UpdateAnnotationCommand(BaseCommand):
self._properties = data.copy()
self._model: Optional[Annotation] = None
@transaction(on_error=partial(on_error, reraise=AnnotationUpdateFailedError))
def run(self) -> Model:
self.validate()
assert self._model
try:
annotation = AnnotationDAO.update(self._model, self._properties)
except DAOUpdateFailedError as ex:
logger.exception(ex.exception)
raise AnnotationUpdateFailedError() from ex
return annotation
return AnnotationDAO.update(self._model, self._properties)
def validate(self) -> None:
exceptions: list[ValidationError] = []

View File

@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from functools import partial
from typing import Any
from flask_appbuilder.models.sqla import Model
@ -27,7 +28,7 @@ from superset.commands.annotation_layer.exceptions import (
)
from superset.commands.base import BaseCommand
from superset.daos.annotation_layer import AnnotationLayerDAO
from superset.daos.exceptions import DAOCreateFailedError
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__)
@ -36,13 +37,10 @@ class CreateAnnotationLayerCommand(BaseCommand):
def __init__(self, data: dict[str, Any]):
self._properties = data.copy()
@transaction(on_error=partial(on_error, reraise=AnnotationLayerCreateFailedError))
def run(self) -> Model:
self.validate()
try:
return AnnotationLayerDAO.create(attributes=self._properties)
except DAOCreateFailedError as ex:
logger.exception(ex.exception)
raise AnnotationLayerCreateFailedError() from ex
return AnnotationLayerDAO.create(attributes=self._properties)
def validate(self) -> None:
exceptions: list[ValidationError] = []

View File

@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from functools import partial
from typing import Optional
from superset.commands.annotation_layer.exceptions import (
@ -24,8 +25,8 @@ from superset.commands.annotation_layer.exceptions import (
)
from superset.commands.base import BaseCommand
from superset.daos.annotation_layer import AnnotationLayerDAO
from superset.daos.exceptions import DAODeleteFailedError
from superset.models.annotations import AnnotationLayer
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__)
@ -35,15 +36,11 @@ class DeleteAnnotationLayerCommand(BaseCommand):
self._model_ids = model_ids
self._models: Optional[list[AnnotationLayer]] = None
@transaction(on_error=partial(on_error, reraise=AnnotationLayerDeleteFailedError))
def run(self) -> None:
self.validate()
assert self._models
try:
AnnotationLayerDAO.delete(self._models)
except DAODeleteFailedError as ex:
logger.exception(ex.exception)
raise AnnotationLayerDeleteFailedError() from ex
AnnotationLayerDAO.delete(self._models)
def validate(self) -> None:
# Validate/populate model exists

View File

@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from functools import partial
from typing import Any, Optional
from flask_appbuilder.models.sqla import Model
@ -28,8 +29,8 @@ from superset.commands.annotation_layer.exceptions import (
)
from superset.commands.base import BaseCommand
from superset.daos.annotation_layer import AnnotationLayerDAO
from superset.daos.exceptions import DAOUpdateFailedError
from superset.models.annotations import AnnotationLayer
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__)
@ -40,16 +41,11 @@ class UpdateAnnotationLayerCommand(BaseCommand):
self._properties = data.copy()
self._model: Optional[AnnotationLayer] = None
@transaction(on_error=partial(on_error, reraise=AnnotationLayerUpdateFailedError))
def run(self) -> Model:
self.validate()
assert self._model
try:
annotation_layer = AnnotationLayerDAO.update(self._model, self._properties)
except DAOUpdateFailedError as ex:
logger.exception(ex.exception)
raise AnnotationLayerUpdateFailedError() from ex
return annotation_layer
return AnnotationLayerDAO.update(self._model, self._properties)
def validate(self) -> None:
exceptions: list[ValidationError] = []

View File

@ -16,6 +16,7 @@
# under the License.
import logging
from datetime import datetime
from functools import partial
from typing import Any, Optional
from flask import g
@ -33,7 +34,7 @@ from superset.commands.chart.exceptions import (
from superset.commands.utils import get_datasource_by_id
from superset.daos.chart import ChartDAO
from superset.daos.dashboard import DashboardDAO
from superset.daos.exceptions import DAOCreateFailedError
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__)
@ -42,15 +43,12 @@ class CreateChartCommand(CreateMixin, BaseCommand):
def __init__(self, data: dict[str, Any]):
self._properties = data.copy()
@transaction(on_error=partial(on_error, reraise=ChartCreateFailedError))
def run(self) -> Model:
self.validate()
try:
self._properties["last_saved_at"] = datetime.now()
self._properties["last_saved_by"] = g.user
return ChartDAO.create(attributes=self._properties)
except DAOCreateFailedError as ex:
logger.exception(ex.exception)
raise ChartCreateFailedError() from ex
self._properties["last_saved_at"] = datetime.now()
self._properties["last_saved_by"] = g.user
return ChartDAO.create(attributes=self._properties)
def validate(self) -> None:
exceptions = []

View File

@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from functools import partial
from typing import Optional
from flask_babel import lazy_gettext as _
@ -28,10 +29,10 @@ from superset.commands.chart.exceptions import (
ChartNotFoundError,
)
from superset.daos.chart import ChartDAO
from superset.daos.exceptions import DAODeleteFailedError
from superset.daos.report import ReportScheduleDAO
from superset.exceptions import SupersetSecurityException
from superset.models.slice import Slice
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__)
@ -41,15 +42,11 @@ class DeleteChartCommand(BaseCommand):
self._model_ids = model_ids
self._models: Optional[list[Slice]] = None
@transaction(on_error=partial(on_error, reraise=ChartDeleteFailedError))
def run(self) -> None:
self.validate()
assert self._models
try:
ChartDAO.delete(self._models)
except DAODeleteFailedError as ex:
logger.exception(ex.exception)
raise ChartDeleteFailedError() from ex
ChartDAO.delete(self._models)
def validate(self) -> None:
# Validate/populate model exists

View File

@ -77,7 +77,7 @@ def import_chart(
if chart.id is None:
db.session.flush()
if user := get_user():
if (user := get_user()) and user not in chart.owners:
chart.owners.append(user)
return chart

View File

@ -16,6 +16,7 @@
# under the License.
import logging
from datetime import datetime
from functools import partial
from typing import Any, Optional
from flask import g
@ -35,10 +36,10 @@ from superset.commands.chart.exceptions import (
from superset.commands.utils import get_datasource_by_id, update_tags, validate_tags
from superset.daos.chart import ChartDAO
from superset.daos.dashboard import DashboardDAO
from superset.daos.exceptions import DAODeleteFailedError, DAOUpdateFailedError
from superset.exceptions import SupersetSecurityException
from superset.models.slice import Slice
from superset.tags.models import ObjectType
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__)
@ -55,24 +56,20 @@ class UpdateChartCommand(UpdateMixin, BaseCommand):
self._properties = data.copy()
self._model: Optional[Slice] = None
@transaction(on_error=partial(on_error, reraise=ChartUpdateFailedError))
def run(self) -> Model:
self.validate()
assert self._model
try:
# Update tags
tags = self._properties.pop("tags", None)
if tags is not None:
update_tags(ObjectType.chart, self._model.id, self._model.tags, tags)
# Update tags
if (tags := self._properties.pop("tags", None)) is not None:
update_tags(ObjectType.chart, self._model.id, self._model.tags, tags)
if self._properties.get("query_context_generation") is None:
self._properties["last_saved_at"] = datetime.now()
self._properties["last_saved_by"] = g.user
chart = ChartDAO.update(self._model, self._properties)
except (DAOUpdateFailedError, DAODeleteFailedError) as ex:
logger.exception(ex.exception)
raise ChartUpdateFailedError() from ex
return chart
if self._properties.get("query_context_generation") is None:
self._properties["last_saved_at"] = datetime.now()
self._properties["last_saved_by"] = g.user
return ChartDAO.update(self._model, self._properties)
def validate(self) -> None:
exceptions: list[ValidationError] = []

View File

@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from functools import partial
from typing import Optional
from superset.commands.base import BaseCommand
@ -23,8 +24,8 @@ from superset.commands.css.exceptions import (
CssTemplateNotFoundError,
)
from superset.daos.css import CssTemplateDAO
from superset.daos.exceptions import DAODeleteFailedError
from superset.models.core import CssTemplate
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__)
@ -34,15 +35,11 @@ class DeleteCssTemplateCommand(BaseCommand):
self._model_ids = model_ids
self._models: Optional[list[CssTemplate]] = None
@transaction(on_error=partial(on_error, reraise=CssTemplateDeleteFailedError))
def run(self) -> None:
self.validate()
assert self._models
try:
CssTemplateDAO.delete(self._models)
except DAODeleteFailedError as ex:
logger.exception(ex.exception)
raise CssTemplateDeleteFailedError() from ex
CssTemplateDAO.delete(self._models)
def validate(self) -> None:
# Validate/populate model exists

View File

@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from functools import partial
from typing import Any, Optional
from flask_appbuilder.models.sqla import Model
@ -28,23 +29,19 @@ from superset.commands.dashboard.exceptions import (
)
from superset.commands.utils import populate_roles
from superset.daos.dashboard import DashboardDAO
from superset.daos.exceptions import DAOCreateFailedError
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__)
class CreateDashboardCommand(CreateMixin, BaseCommand):
def __init__(self, data: dict[str, Any]):
def __init__(self, data: dict[str, Any]) -> None:
self._properties = data.copy()
@transaction(on_error=partial(on_error, reraise=DashboardCreateFailedError))
def run(self) -> Model:
self.validate()
try:
dashboard = DashboardDAO.create(attributes=self._properties, commit=True)
except DAOCreateFailedError as ex:
logger.exception(ex.exception)
raise DashboardCreateFailedError() from ex
return dashboard
return DashboardDAO.create(attributes=self._properties)
def validate(self) -> None:
exceptions: list[ValidationError] = []

View File

@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from functools import partial
from typing import Optional
from flask_babel import lazy_gettext as _
@ -28,10 +29,10 @@ from superset.commands.dashboard.exceptions import (
DashboardNotFoundError,
)
from superset.daos.dashboard import DashboardDAO
from superset.daos.exceptions import DAODeleteFailedError
from superset.daos.report import ReportScheduleDAO
from superset.exceptions import SupersetSecurityException
from superset.models.dashboard import Dashboard
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__)
@ -41,15 +42,11 @@ class DeleteDashboardCommand(BaseCommand):
self._model_ids = model_ids
self._models: Optional[list[Dashboard]] = None
@transaction(on_error=partial(on_error, reraise=DashboardDeleteFailedError))
def run(self) -> None:
self.validate()
assert self._models
try:
DashboardDAO.delete(self._models)
except DAODeleteFailedError as ex:
logger.exception(ex.exception)
raise DashboardDeleteFailedError() from ex
DashboardDAO.delete(self._models)
def validate(self) -> None:
# Validate/populate model exists

View File

@ -36,6 +36,7 @@ from superset.utils.dashboard_filter_scopes_converter import (
convert_filter_scopes,
copy_filter_scopes,
)
from superset.utils.decorators import transaction
logger = logging.getLogger(__name__)
@ -311,7 +312,6 @@ def import_dashboards(
for dashboard in data["dashboards"]:
import_dashboard(dashboard, dataset_id_mapping, import_time=import_time)
db.session.commit()
class ImportDashboardsCommand(BaseCommand):
@ -329,6 +329,7 @@ class ImportDashboardsCommand(BaseCommand):
self.contents = contents
self.database_id = database_id
@transaction()
def run(self) -> None:
self.validate()

View File

@ -188,7 +188,7 @@ def import_dashboard(
if dashboard.id is None:
db.session.flush()
if user := get_user():
if (user := get_user()) and user not in dashboard.owners:
dashboard.owners.append(user)
return dashboard

View File

@ -15,18 +15,22 @@
# specific language governing permissions and limitations
# under the License.
import logging
from functools import partial
from sqlalchemy.exc import SQLAlchemyError
from superset import db
from superset.commands.dashboard.permalink.base import BaseDashboardPermalinkCommand
from superset.commands.key_value.upsert import UpsertKeyValueCommand
from superset.daos.dashboard import DashboardDAO
from superset.dashboards.permalink.exceptions import DashboardPermalinkCreateFailedError
from superset.dashboards.permalink.types import DashboardPermalinkState
from superset.key_value.exceptions import KeyValueCodecEncodeException
from superset.key_value.exceptions import (
KeyValueCodecEncodeException,
KeyValueUpsertFailedError,
)
from superset.key_value.utils import encode_permalink_key, get_deterministic_uuid
from superset.utils.core import get_user_id
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__)
@ -47,29 +51,33 @@ class CreateDashboardPermalinkCommand(BaseDashboardPermalinkCommand):
self.dashboard_id = dashboard_id
self.state = state
@transaction(
on_error=partial(
on_error,
catches=(
KeyValueCodecEncodeException,
KeyValueUpsertFailedError,
SQLAlchemyError,
),
reraise=DashboardPermalinkCreateFailedError,
),
)
def run(self) -> str:
self.validate()
try:
dashboard = DashboardDAO.get_by_id_or_slug(self.dashboard_id)
value = {
"dashboardId": str(dashboard.uuid),
"state": self.state,
}
user_id = get_user_id()
key = UpsertKeyValueCommand(
resource=self.resource,
key=get_deterministic_uuid(self.salt, (user_id, value)),
value=value,
codec=self.codec,
).run()
assert key.id # for type checks
db.session.commit()
return encode_permalink_key(key=key.id, salt=self.salt)
except KeyValueCodecEncodeException as ex:
raise DashboardPermalinkCreateFailedError(str(ex)) from ex
except SQLAlchemyError as ex:
logger.exception("Error running create command")
raise DashboardPermalinkCreateFailedError() from ex
dashboard = DashboardDAO.get_by_id_or_slug(self.dashboard_id)
value = {
"dashboardId": str(dashboard.uuid),
"state": self.state,
}
user_id = get_user_id()
key = UpsertKeyValueCommand(
resource=self.resource,
key=get_deterministic_uuid(self.salt, (user_id, value)),
value=value,
codec=self.codec,
).run()
assert key.id # for type checks
return encode_permalink_key(key=key.id, salt=self.salt)
def validate(self) -> None:
pass

View File

@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from functools import partial
from typing import Any, Optional
from flask_appbuilder.models.sqla import Model
@ -31,12 +32,11 @@ from superset.commands.dashboard.exceptions import (
)
from superset.commands.utils import populate_roles, update_tags, validate_tags
from superset.daos.dashboard import DashboardDAO
from superset.daos.exceptions import DAODeleteFailedError, DAOUpdateFailedError
from superset.exceptions import SupersetSecurityException
from superset.extensions import db
from superset.models.dashboard import Dashboard
from superset.tags.models import ObjectType
from superset.utils import json
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__)
@ -47,29 +47,22 @@ class UpdateDashboardCommand(UpdateMixin, BaseCommand):
self._properties = data.copy()
self._model: Optional[Dashboard] = None
@transaction(on_error=partial(on_error, reraise=DashboardUpdateFailedError))
def run(self) -> Model:
self.validate()
assert self._model
try:
# Update tags
tags = self._properties.pop("tags", None)
if tags is not None:
update_tags(
ObjectType.dashboard, self._model.id, self._model.tags, tags
)
# Update tags
if (tags := self._properties.pop("tags", None)) is not None:
update_tags(ObjectType.dashboard, self._model.id, self._model.tags, tags)
dashboard = DashboardDAO.update(self._model, self._properties)
if self._properties.get("json_metadata"):
DashboardDAO.set_dash_metadata(
dashboard,
data=json.loads(self._properties.get("json_metadata", "{}")),
)
dashboard = DashboardDAO.update(self._model, self._properties, commit=False)
if self._properties.get("json_metadata"):
dashboard = DashboardDAO.set_dash_metadata(
dashboard,
data=json.loads(self._properties.get("json_metadata", "{}")),
commit=False,
)
db.session.commit()
except (DAOUpdateFailedError, DAODeleteFailedError) as ex:
logger.exception(ex.exception)
raise DashboardUpdateFailedError() from ex
return dashboard
def validate(self) -> None:

View File

@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from functools import partial
from typing import Any, Optional
from flask import current_app
@ -39,11 +40,11 @@ from superset.commands.database.ssh_tunnel.exceptions import (
)
from superset.commands.database.test_connection import TestConnectionDatabaseCommand
from superset.daos.database import DatabaseDAO
from superset.daos.exceptions import DAOCreateFailedError
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.exceptions import SupersetErrorsException
from superset.extensions import db, event_logger, security_manager
from superset.extensions import event_logger, security_manager
from superset.models.core import Database
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__)
stats_logger = current_app.config["STATS_LOGGER"]
@ -53,6 +54,7 @@ class CreateDatabaseCommand(BaseCommand):
def __init__(self, data: dict[str, Any]):
self._properties = data.copy()
@transaction(on_error=partial(on_error, reraise=DatabaseCreateFailedError))
def run(self) -> Model:
self.validate()
@ -96,8 +98,6 @@ class CreateDatabaseCommand(BaseCommand):
database, ssh_tunnel_properties
).run()
db.session.commit()
# add catalog/schema permissions
if database.db_engine_spec.supports_catalog:
catalogs = database.get_all_catalog_names(
@ -121,14 +121,12 @@ class CreateDatabaseCommand(BaseCommand):
except Exception: # pylint: disable=broad-except
logger.warning("Error processing catalog '%s'", catalog)
continue
except (
SSHTunnelInvalidError,
SSHTunnelCreateFailedError,
SSHTunnelingNotEnabledError,
SSHTunnelDatabasePortError,
) as ex:
db.session.rollback()
event_logger.log_with_context(
action=f"db_creation_failed.{ex.__class__.__name__}.ssh_tunnel",
engine=self._properties.get("sqlalchemy_uri", "").split(":")[0],
@ -136,11 +134,9 @@ class CreateDatabaseCommand(BaseCommand):
# So we can show the original message
raise
except (
DAOCreateFailedError,
DatabaseInvalidError,
Exception,
) as ex:
db.session.rollback()
event_logger.log_with_context(
action=f"db_creation_failed.{ex.__class__.__name__}",
engine=database.db_engine_spec.__name__,
@ -198,6 +194,6 @@ class CreateDatabaseCommand(BaseCommand):
raise exception
def _create_database(self) -> Database:
database = DatabaseDAO.create(attributes=self._properties, commit=False)
database = DatabaseDAO.create(attributes=self._properties)
database.set_sqlalchemy_uri(database.sqlalchemy_uri)
return database

View File

@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from functools import partial
from typing import Optional
from flask_babel import lazy_gettext as _
@ -27,9 +28,9 @@ from superset.commands.database.exceptions import (
DatabaseNotFoundError,
)
from superset.daos.database import DatabaseDAO
from superset.daos.exceptions import DAODeleteFailedError
from superset.daos.report import ReportScheduleDAO
from superset.models.core import Database
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__)
@ -39,15 +40,11 @@ class DeleteDatabaseCommand(BaseCommand):
self._model_id = model_id
self._model: Optional[Database] = None
@transaction(on_error=partial(on_error, reraise=DatabaseDeleteFailedError))
def run(self) -> None:
self.validate()
assert self._model
try:
DatabaseDAO.delete([self._model])
except DAODeleteFailedError as ex:
logger.exception(ex.exception)
raise DatabaseDeleteFailedError() from ex
DatabaseDAO.delete([self._model])
def validate(self) -> None:
# Validate/populate model exists

View File

@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from functools import partial
from typing import Any, Optional
from flask_appbuilder.models.sqla import Model
@ -28,10 +29,10 @@ from superset.commands.database.ssh_tunnel.exceptions import (
SSHTunnelRequiredFieldValidationError,
)
from superset.daos.database import SSHTunnelDAO
from superset.daos.exceptions import DAOCreateFailedError
from superset.databases.utils import make_url_safe
from superset.extensions import event_logger
from superset.models.core import Database
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__)
@ -44,6 +45,7 @@ class CreateSSHTunnelCommand(BaseCommand):
self._properties["database"] = database
self._database = database
@transaction(on_error=partial(on_error, reraise=SSHTunnelCreateFailedError))
def run(self) -> Model:
"""
Create an SSH tunnel.
@ -53,11 +55,8 @@ class CreateSSHTunnelCommand(BaseCommand):
:raises SSHTunnelInvalidError: If the configuration are invalid
"""
try:
self.validate()
return SSHTunnelDAO.create(attributes=self._properties, commit=False)
except DAOCreateFailedError as ex:
raise SSHTunnelCreateFailedError() from ex
self.validate()
return SSHTunnelDAO.create(attributes=self._properties)
def validate(self) -> None:
# TODO(hughhh): check to make sure the server port is not localhost

View File

@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from functools import partial
from typing import Optional
from superset import is_feature_enabled
@ -25,8 +26,8 @@ from superset.commands.database.ssh_tunnel.exceptions import (
SSHTunnelNotFoundError,
)
from superset.daos.database import SSHTunnelDAO
from superset.daos.exceptions import DAODeleteFailedError
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__)
@ -36,16 +37,13 @@ class DeleteSSHTunnelCommand(BaseCommand):
self._model_id = model_id
self._model: Optional[SSHTunnel] = None
@transaction(on_error=partial(on_error, reraise=SSHTunnelDeleteFailedError))
def run(self) -> None:
if not is_feature_enabled("SSH_TUNNELING"):
raise SSHTunnelingNotEnabledError()
self.validate()
assert self._model
try:
SSHTunnelDAO.delete([self._model])
except DAODeleteFailedError as ex:
raise SSHTunnelDeleteFailedError() from ex
SSHTunnelDAO.delete([self._model])
def validate(self) -> None:
# Validate/populate model exists

View File

@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from functools import partial
from typing import Any, Optional
from flask_appbuilder.models.sqla import Model
@ -28,9 +29,9 @@ from superset.commands.database.ssh_tunnel.exceptions import (
SSHTunnelUpdateFailedError,
)
from superset.daos.database import SSHTunnelDAO
from superset.daos.exceptions import DAOUpdateFailedError
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.databases.utils import make_url_safe
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__)
@ -41,25 +42,23 @@ class UpdateSSHTunnelCommand(BaseCommand):
self._model_id = model_id
self._model: Optional[SSHTunnel] = None
@transaction(on_error=partial(on_error, reraise=SSHTunnelUpdateFailedError))
def run(self) -> Optional[Model]:
self.validate()
try:
if self._model is None:
return None
# unset password if private key is provided
if self._properties.get("private_key"):
self._properties["password"] = None
if self._model is None:
return None
# unset private key and password if password is provided
if self._properties.get("password"):
self._properties["private_key"] = None
self._properties["private_key_password"] = None
# unset password if private key is provided
if self._properties.get("private_key"):
self._properties["password"] = None
tunnel = SSHTunnelDAO.update(self._model, self._properties)
return tunnel
except DAOUpdateFailedError as ex:
raise SSHTunnelUpdateFailedError() from ex
# unset private key and password if password is provided
if self._properties.get("password"):
self._properties["private_key"] = None
self._properties["private_key_password"] = None
return SSHTunnelDAO.update(self._model, self._properties)
def validate(self) -> None:
# Validate/populate model exists

View File

@ -18,6 +18,7 @@
from __future__ import annotations
import logging
from functools import partial
from typing import Any
from flask_appbuilder.models.sqla import Model
@ -34,16 +35,14 @@ from superset.commands.database.exceptions import (
from superset.commands.database.ssh_tunnel.create import CreateSSHTunnelCommand
from superset.commands.database.ssh_tunnel.delete import DeleteSSHTunnelCommand
from superset.commands.database.ssh_tunnel.exceptions import (
SSHTunnelError,
SSHTunnelingNotEnabledError,
)
from superset.commands.database.ssh_tunnel.update import UpdateSSHTunnelCommand
from superset.daos.database import DatabaseDAO
from superset.daos.dataset import DatasetDAO
from superset.daos.exceptions import DAOCreateFailedError, DAOUpdateFailedError
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.extensions import db
from superset.models.core import Database
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__)
@ -56,6 +55,7 @@ class UpdateDatabaseCommand(BaseCommand):
self._model_id = model_id
self._model: Database | None = None
@transaction(on_error=partial(on_error, reraise=DatabaseUpdateFailedError))
def run(self) -> Model:
self._model = DatabaseDAO.find_by_id(self._model_id)
@ -76,21 +76,10 @@ class UpdateDatabaseCommand(BaseCommand):
# since they're name based
original_database_name = self._model.database_name
try:
database = DatabaseDAO.update(
self._model,
self._properties,
commit=False,
)
database.set_sqlalchemy_uri(database.sqlalchemy_uri)
ssh_tunnel = self._handle_ssh_tunnel(database)
self._refresh_catalogs(database, original_database_name, ssh_tunnel)
except SSHTunnelError: # pylint: disable=try-except-raise
# allow exception to bubble for debugging information
raise
except (DAOUpdateFailedError, DAOCreateFailedError) as ex:
raise DatabaseUpdateFailedError() from ex
database = DatabaseDAO.update(self._model, self._properties)
database.set_sqlalchemy_uri(database.sqlalchemy_uri)
ssh_tunnel = self._handle_ssh_tunnel(database)
self._refresh_catalogs(database, original_database_name, ssh_tunnel)
return database
def _handle_ssh_tunnel(self, database: Database) -> SSHTunnel | None:
@ -101,7 +90,6 @@ class UpdateDatabaseCommand(BaseCommand):
return None
if not is_feature_enabled("SSH_TUNNELING"):
db.session.rollback()
raise SSHTunnelingNotEnabledError()
current_ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
@ -131,13 +119,13 @@ class UpdateDatabaseCommand(BaseCommand):
This method captures a generic exception, since errors could potentially come
from any of the 50+ database drivers we support.
"""
try:
return database.get_all_catalog_names(
force=True,
ssh_tunnel=ssh_tunnel,
)
except Exception as ex:
db.session.rollback()
raise DatabaseConnectionFailedError() from ex
def _get_schema_names(
@ -152,6 +140,7 @@ class UpdateDatabaseCommand(BaseCommand):
This method captures a generic exception, since errors could potentially come
from any of the 50+ database drivers we support.
"""
try:
return database.get_all_schema_names(
force=True,
@ -159,7 +148,6 @@ class UpdateDatabaseCommand(BaseCommand):
ssh_tunnel=ssh_tunnel,
)
except Exception as ex:
db.session.rollback()
raise DatabaseConnectionFailedError() from ex
def _refresh_catalogs(
@ -225,8 +213,6 @@ class UpdateDatabaseCommand(BaseCommand):
schemas,
)
db.session.commit()
def _refresh_schemas(
self,
database: Database,

View File

@ -16,11 +16,11 @@
# under the License.
import logging
from abc import abstractmethod
from functools import partial
from typing import Any, Optional, TypedDict
import pandas as pd
from flask_babel import lazy_gettext as _
from sqlalchemy.exc import SQLAlchemyError
from werkzeug.datastructures import FileStorage
from superset import db
@ -37,6 +37,7 @@ from superset.daos.database import DatabaseDAO
from superset.models.core import Database
from superset.sql_parse import Table
from superset.utils.core import get_user
from superset.utils.decorators import on_error, transaction
from superset.views.database.validators import schema_allows_file_upload
logger = logging.getLogger(__name__)
@ -144,6 +145,7 @@ class UploadCommand(BaseCommand):
self._file = file
self._reader = reader
@transaction(on_error=partial(on_error, reraise=DatabaseUploadSaveMetadataFailed))
def run(self) -> None:
self.validate()
if not self._model:
@ -172,12 +174,6 @@ class UploadCommand(BaseCommand):
sqla_table.fetch_metadata()
try:
db.session.commit()
except SQLAlchemyError as ex:
db.session.rollback()
raise DatabaseUploadSaveMetadataFailed() from ex
def validate(self) -> None:
self._model = DatabaseDAO.find_by_id(self._model_id)
if not self._model:

View File

@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from functools import partial
from typing import Optional
from superset import security_manager
@ -26,8 +27,8 @@ from superset.commands.dataset.columns.exceptions import (
)
from superset.connectors.sqla.models import TableColumn
from superset.daos.dataset import DatasetColumnDAO, DatasetDAO
from superset.daos.exceptions import DAODeleteFailedError
from superset.exceptions import SupersetSecurityException
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__)
@ -38,15 +39,11 @@ class DeleteDatasetColumnCommand(BaseCommand):
self._model_id = model_id
self._model: Optional[TableColumn] = None
@transaction(on_error=partial(on_error, reraise=DatasetColumnDeleteFailedError))
def run(self) -> None:
self.validate()
assert self._model
try:
DatasetColumnDAO.delete([self._model])
except DAODeleteFailedError as ex:
logger.exception(ex.exception)
raise DatasetColumnDeleteFailedError() from ex
DatasetColumnDAO.delete([self._model])
def validate(self) -> None:
# Validate/populate model exists

View File

@ -15,11 +15,11 @@
# specific language governing permissions and limitations
# under the License.
import logging
from functools import partial
from typing import Any, Optional
from flask_appbuilder.models.sqla import Model
from marshmallow import ValidationError
from sqlalchemy.exc import SQLAlchemyError
from superset.commands.base import BaseCommand, CreateMixin
from superset.commands.dataset.exceptions import (
@ -31,10 +31,10 @@ from superset.commands.dataset.exceptions import (
TableNotFoundValidationError,
)
from superset.daos.dataset import DatasetDAO
from superset.daos.exceptions import DAOCreateFailedError
from superset.exceptions import SupersetSecurityException
from superset.extensions import db, security_manager
from superset.extensions import security_manager
from superset.sql_parse import Table
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__)
@ -43,19 +43,12 @@ class CreateDatasetCommand(CreateMixin, BaseCommand):
def __init__(self, data: dict[str, Any]):
self._properties = data.copy()
@transaction(on_error=partial(on_error, reraise=DatasetCreateFailedError))
def run(self) -> Model:
self.validate()
try:
# Creates SqlaTable (Dataset)
dataset = DatasetDAO.create(attributes=self._properties, commit=False)
# Updates columns and metrics from the dataset
dataset.fetch_metadata(commit=False)
db.session.commit()
except (SQLAlchemyError, DAOCreateFailedError) as ex:
logger.warning(ex, exc_info=True)
db.session.rollback()
raise DatasetCreateFailedError() from ex
dataset = DatasetDAO.create(attributes=self._properties)
dataset.fetch_metadata()
return dataset
def validate(self) -> None:

View File

@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from functools import partial
from typing import Optional
from superset import security_manager
@ -26,8 +27,8 @@ from superset.commands.dataset.exceptions import (
)
from superset.connectors.sqla.models import SqlaTable
from superset.daos.dataset import DatasetDAO
from superset.daos.exceptions import DAODeleteFailedError
from superset.exceptions import SupersetSecurityException
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__)
@ -37,15 +38,11 @@ class DeleteDatasetCommand(BaseCommand):
self._model_ids = model_ids
self._models: Optional[list[SqlaTable]] = None
@transaction(on_error=partial(on_error, reraise=DatasetDeleteFailedError))
def run(self) -> None:
self.validate()
assert self._models
try:
DatasetDAO.delete(self._models)
except DAODeleteFailedError as ex:
logger.exception(ex.exception)
raise DatasetDeleteFailedError() from ex
DatasetDAO.delete(self._models)
def validate(self) -> None:
# Validate/populate model exists

View File

@ -15,12 +15,12 @@
# specific language governing permissions and limitations
# under the License.
import logging
from functools import partial
from typing import Any
from flask_appbuilder.models.sqla import Model
from flask_babel import gettext as __
from marshmallow import ValidationError
from sqlalchemy.exc import SQLAlchemyError
from superset.commands.base import BaseCommand, CreateMixin
from superset.commands.dataset.exceptions import (
@ -32,12 +32,12 @@ from superset.commands.dataset.exceptions import (
from superset.commands.exceptions import DatasourceTypeInvalidError
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
from superset.daos.dataset import DatasetDAO
from superset.daos.exceptions import DAOCreateFailedError
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import SupersetErrorException
from superset.extensions import db
from superset.models.core import Database
from superset.sql_parse import ParsedQuery, Table
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__)
@ -47,66 +47,61 @@ class DuplicateDatasetCommand(CreateMixin, BaseCommand):
self._base_model: SqlaTable = SqlaTable()
self._properties = data.copy()
@transaction(on_error=partial(on_error, reraise=DatasetDuplicateFailedError))
def run(self) -> Model:
self.validate()
try:
database_id = self._base_model.database_id
table_name = self._properties["table_name"]
owners = self._properties["owners"]
database = db.session.query(Database).get(database_id)
if not database:
raise SupersetErrorException(
SupersetError(
message=__("The database was not found."),
error_type=SupersetErrorType.DATABASE_NOT_FOUND_ERROR,
level=ErrorLevel.ERROR,
),
status=404,
)
table = SqlaTable(table_name=table_name, owners=owners)
table.database = database
table.schema = self._base_model.schema
table.template_params = self._base_model.template_params
table.normalize_columns = self._base_model.normalize_columns
table.always_filter_main_dttm = self._base_model.always_filter_main_dttm
table.is_sqllab_view = True
table.sql = ParsedQuery(
self._base_model.sql,
engine=database.db_engine_spec.engine,
).stripped()
db.session.add(table)
cols = []
for config_ in self._base_model.columns:
column_name = config_.column_name
col = TableColumn(
column_name=column_name,
verbose_name=config_.verbose_name,
expression=config_.expression,
filterable=True,
groupby=True,
is_dttm=config_.is_dttm,
type=config_.type,
description=config_.description,
)
cols.append(col)
table.columns = cols
mets = []
for config_ in self._base_model.metrics:
metric_name = config_.metric_name
met = SqlMetric(
metric_name=metric_name,
verbose_name=config_.verbose_name,
expression=config_.expression,
metric_type=config_.metric_type,
description=config_.description,
)
mets.append(met)
table.metrics = mets
db.session.commit()
except (SQLAlchemyError, DAOCreateFailedError) as ex:
logger.warning(ex, exc_info=True)
db.session.rollback()
raise DatasetDuplicateFailedError() from ex
database_id = self._base_model.database_id
table_name = self._properties["table_name"]
owners = self._properties["owners"]
database = db.session.query(Database).get(database_id)
if not database:
raise SupersetErrorException(
SupersetError(
message=__("The database was not found."),
error_type=SupersetErrorType.DATABASE_NOT_FOUND_ERROR,
level=ErrorLevel.ERROR,
),
status=404,
)
table = SqlaTable(table_name=table_name, owners=owners)
table.database = database
table.schema = self._base_model.schema
table.template_params = self._base_model.template_params
table.normalize_columns = self._base_model.normalize_columns
table.always_filter_main_dttm = self._base_model.always_filter_main_dttm
table.is_sqllab_view = True
table.sql = ParsedQuery(
self._base_model.sql,
engine=database.db_engine_spec.engine,
).stripped()
db.session.add(table)
cols = []
for config_ in self._base_model.columns:
column_name = config_.column_name
col = TableColumn(
column_name=column_name,
verbose_name=config_.verbose_name,
expression=config_.expression,
filterable=True,
groupby=True,
is_dttm=config_.is_dttm,
type=config_.type,
description=config_.description,
)
cols.append(col)
table.columns = cols
mets = []
for config_ in self._base_model.metrics:
metric_name = config_.metric_name
met = SqlMetric(
metric_name=metric_name,
verbose_name=config_.verbose_name,
expression=config_.expression,
metric_type=config_.metric_type,
description=config_.description,
)
mets.append(met)
table.metrics = mets
return table
def validate(self) -> None:

View File

@ -34,6 +34,7 @@ from superset.connectors.sqla.models import (
)
from superset.models.core import Database
from superset.utils import json
from superset.utils.decorators import transaction
from superset.utils.dict_import_export import DATABASES_KEY
logger = logging.getLogger(__name__)
@ -211,7 +212,6 @@ def import_from_dict(data: dict[str, Any], sync: Optional[list[str]] = None) ->
logger.info("Importing %d %s", len(data.get(DATABASES_KEY, [])), DATABASES_KEY)
for database in data.get(DATABASES_KEY, []):
Database.import_from_dict(database, sync=sync)
db.session.commit()
else:
logger.info("Supplied object is not a dictionary.")
@ -240,10 +240,10 @@ class ImportDatasetsCommand(BaseCommand):
if kwargs.get("sync_metrics"):
self.sync.append("metrics")
@transaction()
def run(self) -> None:
self.validate()
# TODO (betodealmeida): add rollback in case of error
for file_name, config in self._configs.items():
logger.info("Importing dataset from file %s", file_name)
if isinstance(config, dict):
@ -260,7 +260,6 @@ class ImportDatasetsCommand(BaseCommand):
)
dataset["database_id"] = database.id
SqlaTable.import_from_dict(dataset, sync=self.sync)
db.session.commit()
def validate(self) -> None:
# ensure all files are YAML

View File

@ -178,7 +178,7 @@ def import_dataset(
if data_uri and (not table_exists or force_data):
load_data(data_uri, dataset, dataset.database)
if user := get_user():
if (user := get_user()) and user not in dataset.owners:
dataset.owners.append(user)
return dataset

View File

@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from functools import partial
from typing import Optional
from superset import security_manager
@ -26,8 +27,8 @@ from superset.commands.dataset.metrics.exceptions import (
)
from superset.connectors.sqla.models import SqlMetric
from superset.daos.dataset import DatasetDAO, DatasetMetricDAO
from superset.daos.exceptions import DAODeleteFailedError
from superset.exceptions import SupersetSecurityException
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__)
@ -38,15 +39,11 @@ class DeleteDatasetMetricCommand(BaseCommand):
self._model_id = model_id
self._model: Optional[SqlMetric] = None
@transaction(on_error=partial(on_error, reraise=DatasetMetricDeleteFailedError))
def run(self) -> None:
self.validate()
assert self._model
try:
DatasetMetricDAO.delete([self._model])
except DAODeleteFailedError as ex:
logger.exception(ex.exception)
raise DatasetMetricDeleteFailedError() from ex
DatasetMetricDAO.delete([self._model])
def validate(self) -> None:
# Validate/populate model exists

View File

@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from functools import partial
from typing import Optional
from flask_appbuilder.models.sqla import Model
@ -29,6 +30,7 @@ from superset.commands.dataset.exceptions import (
from superset.connectors.sqla.models import SqlaTable
from superset.daos.dataset import DatasetDAO
from superset.exceptions import SupersetSecurityException
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__)
@ -38,16 +40,12 @@ class RefreshDatasetCommand(BaseCommand):
self._model_id = model_id
self._model: Optional[SqlaTable] = None
@transaction(on_error=partial(on_error, reraise=DatasetRefreshFailedError))
def run(self) -> Model:
self.validate()
if self._model:
try:
self._model.fetch_metadata()
return self._model
except Exception as ex:
logger.exception(ex)
raise DatasetRefreshFailedError() from ex
raise DatasetRefreshFailedError()
assert self._model
self._model.fetch_metadata()
return self._model
def validate(self) -> None:
# Validate/populate model exists

View File

@ -16,10 +16,12 @@
# under the License.
import logging
from collections import Counter
from functools import partial
from typing import Any, Optional
from flask_appbuilder.models.sqla import Model
from marshmallow import ValidationError
from sqlalchemy.exc import SQLAlchemyError
from superset import security_manager
from superset.commands.base import BaseCommand, UpdateMixin
@ -39,9 +41,9 @@ from superset.commands.dataset.exceptions import (
)
from superset.connectors.sqla.models import SqlaTable
from superset.daos.dataset import DatasetDAO
from superset.daos.exceptions import DAOUpdateFailedError
from superset.exceptions import SupersetSecurityException
from superset.sql_parse import Table
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__)
@ -59,19 +61,20 @@ class UpdateDatasetCommand(UpdateMixin, BaseCommand):
self.override_columns = override_columns
self._properties["override_columns"] = override_columns
@transaction(
on_error=partial(
on_error,
catches=(
SQLAlchemyError,
ValueError,
),
reraise=DatasetUpdateFailedError,
)
)
def run(self) -> Model:
self.validate()
if self._model:
try:
dataset = DatasetDAO.update(
self._model,
attributes=self._properties,
)
return dataset
except DAOUpdateFailedError as ex:
logger.exception(ex.exception)
raise DatasetUpdateFailedError() from ex
raise DatasetUpdateFailedError()
assert self._model
return DatasetDAO.update(self._model, attributes=self._properties)
def validate(self) -> None:
exceptions: list[ValidationError] = []

View File

@ -15,18 +15,22 @@
# specific language governing permissions and limitations
# under the License.
import logging
from functools import partial
from typing import Any, Optional
from sqlalchemy.exc import SQLAlchemyError
from superset import db
from superset.commands.explore.permalink.base import BaseExplorePermalinkCommand
from superset.commands.key_value.create import CreateKeyValueCommand
from superset.explore.permalink.exceptions import ExplorePermalinkCreateFailedError
from superset.explore.utils import check_access as check_chart_access
from superset.key_value.exceptions import KeyValueCodecEncodeException
from superset.key_value.exceptions import (
KeyValueCodecEncodeException,
KeyValueCreateFailedError,
)
from superset.key_value.utils import encode_permalink_key
from superset.utils.core import DatasourceType
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__)
@ -37,35 +41,39 @@ class CreateExplorePermalinkCommand(BaseExplorePermalinkCommand):
self.datasource: str = state["formData"]["datasource"]
self.state = state
@transaction(
on_error=partial(
on_error,
catches=(
KeyValueCodecEncodeException,
KeyValueCreateFailedError,
SQLAlchemyError,
),
reraise=ExplorePermalinkCreateFailedError,
),
)
def run(self) -> str:
self.validate()
try:
d_id, d_type = self.datasource.split("__")
datasource_id = int(d_id)
datasource_type = DatasourceType(d_type)
check_chart_access(datasource_id, self.chart_id, datasource_type)
value = {
"chartId": self.chart_id,
"datasourceId": datasource_id,
"datasourceType": datasource_type.value,
"datasource": self.datasource,
"state": self.state,
}
command = CreateKeyValueCommand(
resource=self.resource,
value=value,
codec=self.codec,
)
key = command.run()
if key.id is None:
raise ExplorePermalinkCreateFailedError("Unexpected missing key id")
db.session.commit()
return encode_permalink_key(key=key.id, salt=self.salt)
except KeyValueCodecEncodeException as ex:
raise ExplorePermalinkCreateFailedError(str(ex)) from ex
except SQLAlchemyError as ex:
logger.exception("Error running create command")
raise ExplorePermalinkCreateFailedError() from ex
d_id, d_type = self.datasource.split("__")
datasource_id = int(d_id)
datasource_type = DatasourceType(d_type)
check_chart_access(datasource_id, self.chart_id, datasource_type)
value = {
"chartId": self.chart_id,
"datasourceId": datasource_id,
"datasourceType": datasource_type.value,
"datasource": self.datasource,
"state": self.state,
}
command = CreateKeyValueCommand(
resource=self.resource,
value=value,
codec=self.codec,
)
key = command.run()
if key.id is None:
raise ExplorePermalinkCreateFailedError("Unexpected missing key id")
return encode_permalink_key(key=key.id, salt=self.salt)
def validate(self) -> None:
pass

View File

@ -32,6 +32,7 @@ from superset.commands.importers.v1.utils import (
)
from superset.daos.base import BaseDAO
from superset.models.core import Database # noqa: F401
from superset.utils.decorators import transaction
class ImportModelsCommand(BaseCommand):
@ -67,18 +68,15 @@ class ImportModelsCommand(BaseCommand):
def _get_uuids(cls) -> set[str]:
return {str(model.uuid) for model in db.session.query(cls.dao.model_cls).all()}
@transaction()
def run(self) -> None:
self.validate()
# rollback to prevent partial imports
try:
self._import(self._configs, self.overwrite)
db.session.commit()
except CommandException:
db.session.rollback()
raise
except Exception as ex:
db.session.rollback()
raise self.import_error() from ex
def validate(self) -> None: # noqa: F811

View File

@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from functools import partial
from typing import Any, Optional
from marshmallow import Schema
@ -44,6 +45,7 @@ from superset.datasets.schemas import ImportV1DatasetSchema
from superset.migrations.shared.native_filters import migrate_dashboard
from superset.models.dashboard import dashboard_slices
from superset.queries.saved_queries.schemas import ImportV1SavedQuerySchema
from superset.utils.decorators import on_error, transaction
class ImportAssetsCommand(BaseCommand):
@ -153,16 +155,16 @@ class ImportAssetsCommand(BaseCommand):
if chart.viz_type == "filter_box":
db.session.delete(chart)
@transaction(
on_error=partial(
on_error,
catches=(Exception,),
reraise=ImportFailedError,
)
)
def run(self) -> None:
self.validate()
# rollback to prevent partial imports
try:
self._import(self._configs)
db.session.commit()
except Exception as ex:
db.session.rollback()
raise ImportFailedError() from ex
self._import(self._configs)
def validate(self) -> None:
exceptions: list[ValidationError] = []

View File

@ -43,6 +43,7 @@ from superset.datasets.schemas import ImportV1DatasetSchema
from superset.models.dashboard import dashboard_slices
from superset.utils.core import get_example_default_schema
from superset.utils.database import get_example_database
from superset.utils.decorators import transaction
class ImportExamplesCommand(ImportModelsCommand):
@ -62,19 +63,17 @@ class ImportExamplesCommand(ImportModelsCommand):
super().__init__(contents, *args, **kwargs)
self.force_data = kwargs.get("force_data", False)
@transaction()
def run(self) -> None:
self.validate()
# rollback to prevent partial imports
try:
self._import(
self._configs,
self.overwrite,
self.force_data,
)
db.session.commit()
except Exception as ex:
db.session.rollback()
raise self.import_error() from ex
@classmethod

View File

@ -16,17 +16,17 @@
# under the License.
import logging
from datetime import datetime
from functools import partial
from typing import Any, Optional, Union
from uuid import UUID
from sqlalchemy.exc import SQLAlchemyError
from superset import db
from superset.commands.base import BaseCommand
from superset.key_value.exceptions import KeyValueCreateFailedError
from superset.key_value.models import KeyValueEntry
from superset.key_value.types import Key, KeyValueCodec, KeyValueResource
from superset.utils.core import get_user_id
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__)
@ -62,6 +62,7 @@ class CreateKeyValueCommand(BaseCommand):
self.key = key
self.expires_on = expires_on
@transaction(on_error=partial(on_error, reraise=KeyValueCreateFailedError))
def run(self) -> Key:
"""
Persist the value
@ -69,11 +70,8 @@ class CreateKeyValueCommand(BaseCommand):
:return: the key associated with the persisted value
"""
try:
return self.create()
except SQLAlchemyError as ex:
db.session.rollback()
raise KeyValueCreateFailedError() from ex
return self.create()
def validate(self) -> None:
pass
@ -98,6 +96,7 @@ class CreateKeyValueCommand(BaseCommand):
entry.id = self.key
except ValueError as ex:
raise KeyValueCreateFailedError() from ex
db.session.add(entry)
db.session.flush()
return Key(id=entry.id, uuid=entry.uuid)

View File

@ -15,17 +15,17 @@
# specific language governing permissions and limitations
# under the License.
import logging
from functools import partial
from typing import Union
from uuid import UUID
from sqlalchemy.exc import SQLAlchemyError
from superset import db
from superset.commands.base import BaseCommand
from superset.key_value.exceptions import KeyValueDeleteFailedError
from superset.key_value.models import KeyValueEntry
from superset.key_value.types import KeyValueResource
from superset.key_value.utils import get_filter
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__)
@ -45,20 +45,19 @@ class DeleteKeyValueCommand(BaseCommand):
self.resource = resource
self.key = key
@transaction(on_error=partial(on_error, reraise=KeyValueDeleteFailedError))
def run(self) -> bool:
try:
return self.delete()
except SQLAlchemyError as ex:
db.session.rollback()
raise KeyValueDeleteFailedError() from ex
return self.delete()
def validate(self) -> None:
pass
def delete(self) -> bool:
filter_ = get_filter(self.resource, self.key)
if entry := db.session.query(KeyValueEntry).filter_by(**filter_).first():
if (
entry := db.session.query(KeyValueEntry)
.filter_by(**get_filter(self.resource, self.key))
.first()
):
db.session.delete(entry)
db.session.flush()
return True
return False

View File

@ -16,15 +16,16 @@
# under the License.
import logging
from datetime import datetime
from functools import partial
from sqlalchemy import and_
from sqlalchemy.exc import SQLAlchemyError
from superset import db
from superset.commands.base import BaseCommand
from superset.key_value.exceptions import KeyValueDeleteFailedError
from superset.key_value.models import KeyValueEntry
from superset.key_value.types import KeyValueResource
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__)
@ -41,12 +42,9 @@ class DeleteExpiredKeyValueCommand(BaseCommand):
"""
self.resource = resource
@transaction(on_error=partial(on_error, reraise=KeyValueDeleteFailedError))
def run(self) -> None:
try:
self.delete_expired()
except SQLAlchemyError as ex:
db.session.rollback()
raise KeyValueDeleteFailedError() from ex
self.delete_expired()
def validate(self) -> None:
pass
@ -62,4 +60,3 @@ class DeleteExpiredKeyValueCommand(BaseCommand):
)
.delete()
)
db.session.flush()

View File

@ -17,11 +17,10 @@
import logging
from datetime import datetime
from functools import partial
from typing import Any, Optional, Union
from uuid import UUID
from sqlalchemy.exc import SQLAlchemyError
from superset import db
from superset.commands.base import BaseCommand
from superset.key_value.exceptions import KeyValueUpdateFailedError
@ -29,6 +28,7 @@ from superset.key_value.models import KeyValueEntry
from superset.key_value.types import Key, KeyValueCodec, KeyValueResource
from superset.key_value.utils import get_filter
from superset.utils.core import get_user_id
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__)
@ -64,12 +64,9 @@ class UpdateKeyValueCommand(BaseCommand):
self.codec = codec
self.expires_on = expires_on
@transaction(on_error=partial(on_error, reraise=KeyValueUpdateFailedError))
def run(self) -> Optional[Key]:
try:
return self.update()
except SQLAlchemyError as ex:
db.session.rollback()
raise KeyValueUpdateFailedError() from ex
return self.update()
def validate(self) -> None:
pass

View File

@ -17,6 +17,7 @@
import logging
from datetime import datetime
from functools import partial
from typing import Any, Optional, Union
from uuid import UUID
@ -33,6 +34,7 @@ from superset.key_value.models import KeyValueEntry
from superset.key_value.types import Key, KeyValueCodec, KeyValueResource
from superset.key_value.utils import get_filter
from superset.utils.core import get_user_id
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__)
@ -68,27 +70,29 @@ class UpsertKeyValueCommand(BaseCommand):
self.codec = codec
self.expires_on = expires_on
@transaction(
on_error=partial(
on_error,
catches=(KeyValueCreateFailedError, SQLAlchemyError),
reraise=KeyValueUpsertFailedError,
),
)
def run(self) -> Key:
try:
return self.upsert()
except (KeyValueCreateFailedError, SQLAlchemyError) as ex:
db.session.rollback()
raise KeyValueUpsertFailedError() from ex
return self.upsert()
def validate(self) -> None:
pass
def upsert(self) -> Key:
filter_ = get_filter(self.resource, self.key)
entry: KeyValueEntry = (
db.session.query(KeyValueEntry).filter_by(**filter_).first()
)
if entry:
if (
entry := db.session.query(KeyValueEntry)
.filter_by(**get_filter(self.resource, self.key))
.first()
):
entry.value = self.codec.encode(self.value)
entry.expires_on = self.expires_on
entry.changed_on = datetime.now()
entry.changed_by_fk = get_user_id()
db.session.flush()
return Key(entry.id, entry.uuid)
return CreateKeyValueCommand(

View File

@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from functools import partial
from typing import Optional
from superset.commands.base import BaseCommand
@ -22,9 +23,9 @@ from superset.commands.query.exceptions import (
SavedQueryDeleteFailedError,
SavedQueryNotFoundError,
)
from superset.daos.exceptions import DAODeleteFailedError
from superset.daos.query import SavedQueryDAO
from superset.models.dashboard import Dashboard
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__)
@ -34,15 +35,11 @@ class DeleteSavedQueryCommand(BaseCommand):
self._model_ids = model_ids
self._models: Optional[list[Dashboard]] = None
@transaction(on_error=partial(on_error, reraise=SavedQueryDeleteFailedError))
def run(self) -> None:
self.validate()
assert self._models
try:
SavedQueryDAO.delete(self._models)
except DAODeleteFailedError as ex:
logger.exception(ex.exception)
raise SavedQueryDeleteFailedError() from ex
SavedQueryDAO.delete(self._models)
def validate(self) -> None:
# Validate/populate model exists

View File

@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from functools import partial
from typing import Any, Optional
from flask_babel import gettext as _
@ -31,7 +32,6 @@ from superset.commands.report.exceptions import (
ReportScheduleNameUniquenessValidationError,
)
from superset.daos.database import DatabaseDAO
from superset.daos.exceptions import DAOCreateFailedError
from superset.daos.report import ReportScheduleDAO
from superset.reports.models import (
ReportCreationMethod,
@ -40,6 +40,7 @@ from superset.reports.models import (
)
from superset.reports.types import ReportScheduleExtra
from superset.utils import json
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__)
@ -48,13 +49,10 @@ class CreateReportScheduleCommand(CreateMixin, BaseReportScheduleCommand):
def __init__(self, data: dict[str, Any]):
self._properties = data.copy()
@transaction(on_error=partial(on_error, reraise=ReportScheduleCreateFailedError))
def run(self) -> ReportSchedule:
self.validate()
try:
return ReportScheduleDAO.create(attributes=self._properties)
except DAOCreateFailedError as ex:
logger.exception(ex.exception)
raise ReportScheduleCreateFailedError() from ex
return ReportScheduleDAO.create(attributes=self._properties)
def validate(self) -> None:
"""

View File

@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from functools import partial
from typing import Optional
from superset import security_manager
@ -24,10 +25,10 @@ from superset.commands.report.exceptions import (
ReportScheduleForbiddenError,
ReportScheduleNotFoundError,
)
from superset.daos.exceptions import DAODeleteFailedError
from superset.daos.report import ReportScheduleDAO
from superset.exceptions import SupersetSecurityException
from superset.reports.models import ReportSchedule
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__)
@ -37,15 +38,11 @@ class DeleteReportScheduleCommand(BaseCommand):
self._model_ids = model_ids
self._models: Optional[list[ReportSchedule]] = None
@transaction(on_error=partial(on_error, reraise=ReportScheduleDeleteFailedError))
def run(self) -> None:
self.validate()
assert self._models
try:
ReportScheduleDAO.delete(self._models)
except DAODeleteFailedError as ex:
logger.exception(ex.exception)
raise ReportScheduleDeleteFailedError() from ex
ReportScheduleDAO.delete(self._models)
def validate(self) -> None:
# Validate/populate model exists

View File

@ -69,7 +69,7 @@ from superset.tasks.utils import get_executor
from superset.utils import json
from superset.utils.core import HeaderDataType, override_user
from superset.utils.csv import get_chart_csv_data, get_chart_dataframe
from superset.utils.decorators import logs_context
from superset.utils.decorators import logs_context, transaction
from superset.utils.pdf import build_pdf_from_screenshots
from superset.utils.screenshots import ChartScreenshot, DashboardScreenshot
from superset.utils.urls import get_url_path
@ -120,7 +120,6 @@ class BaseReportState:
self._report_schedule.last_state = state
self._report_schedule.last_eval_dttm = datetime.utcnow()
db.session.commit()
def create_log(self, error_message: Optional[str] = None) -> None:
"""
@ -138,7 +137,7 @@ class BaseReportState:
uuid=self._execution_id,
)
db.session.add(log)
db.session.commit()
db.session.commit() # pylint: disable=consider-using-transaction
def _get_url(
self,
@ -690,6 +689,7 @@ class ReportScheduleStateMachine: # pylint: disable=too-few-public-methods
self._report_schedule = report_schedule
self._scheduled_dttm = scheduled_dttm
@transaction()
def run(self) -> None:
for state_cls in self.states_cls:
if (self._report_schedule.last_state is None and state_cls.initial) or (
@ -718,6 +718,7 @@ class AsyncExecuteReportScheduleCommand(BaseCommand):
self._scheduled_dttm = scheduled_dttm
self._execution_id = UUID(task_id)
@transaction()
def run(self) -> None:
try:
self.validate()

View File

@ -17,12 +17,14 @@
import logging
from datetime import datetime, timedelta
from sqlalchemy.exc import SQLAlchemyError
from superset import db
from superset.commands.base import BaseCommand
from superset.commands.report.exceptions import ReportSchedulePruneLogError
from superset.daos.exceptions import DAODeleteFailedError
from superset.daos.report import ReportScheduleDAO
from superset.reports.models import ReportSchedule
from superset.utils.decorators import transaction
logger = logging.getLogger(__name__)
@ -32,9 +34,7 @@ class AsyncPruneReportScheduleLogCommand(BaseCommand):
Prunes logs from all report schedules
"""
def __init__(self, worker_context: bool = True):
self._worker_context = worker_context
@transaction()
def run(self) -> None:
self.validate()
prune_errors = []
@ -46,15 +46,15 @@ class AsyncPruneReportScheduleLogCommand(BaseCommand):
)
try:
row_count = ReportScheduleDAO.bulk_delete_logs(
report_schedule, from_date, commit=False
report_schedule,
from_date,
)
db.session.commit()
logger.info(
"Deleted %s logs for report schedule id: %s",
str(row_count),
str(report_schedule.id),
)
except DAODeleteFailedError as ex:
except SQLAlchemyError as ex:
prune_errors.append(str(ex))
if prune_errors:
raise ReportSchedulePruneLogError(";".join(prune_errors))

View File

@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from functools import partial
from typing import Any, Optional
from flask_appbuilder.models.sqla import Model
@ -32,11 +33,11 @@ from superset.commands.report.exceptions import (
ReportScheduleUpdateFailedError,
)
from superset.daos.database import DatabaseDAO
from superset.daos.exceptions import DAOUpdateFailedError
from superset.daos.report import ReportScheduleDAO
from superset.exceptions import SupersetSecurityException
from superset.reports.models import ReportSchedule, ReportScheduleType, ReportState
from superset.utils import json
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__)
@ -47,16 +48,10 @@ class UpdateReportScheduleCommand(UpdateMixin, BaseReportScheduleCommand):
self._properties = data.copy()
self._model: Optional[ReportSchedule] = None
@transaction(on_error=partial(on_error, reraise=ReportScheduleUpdateFailedError))
def run(self) -> Model:
self.validate()
assert self._model
try:
report_schedule = ReportScheduleDAO.update(self._model, self._properties)
except DAOUpdateFailedError as ex:
logger.exception(ex.exception)
raise ReportScheduleUpdateFailedError() from ex
return report_schedule
return ReportScheduleDAO.update(self._model, self._properties)
def validate(self) -> None:
"""

View File

@ -23,9 +23,9 @@ from superset.commands.base import BaseCommand
from superset.commands.exceptions import DatasourceNotFoundValidationError
from superset.commands.utils import populate_roles
from superset.connectors.sqla.models import SqlaTable
from superset.daos.exceptions import DAOCreateFailedError
from superset.daos.security import RLSDAO
from superset.extensions import db
from superset.utils.decorators import transaction
logger = logging.getLogger(__name__)
@ -36,13 +36,10 @@ class CreateRLSRuleCommand(BaseCommand):
self._tables = self._properties.get("tables", [])
self._roles = self._properties.get("roles", [])
@transaction()
def run(self) -> Any:
self.validate()
try:
return RLSDAO.create(attributes=self._properties)
except DAOCreateFailedError as ex:
logger.exception(ex.exception)
raise
return RLSDAO.create(attributes=self._properties)
def validate(self) -> None:
roles = populate_roles(self._roles)

View File

@ -16,15 +16,16 @@
# under the License.
import logging
from functools import partial
from superset.commands.base import BaseCommand
from superset.commands.security.exceptions import (
RLSRuleNotFoundError,
RuleDeleteFailedError,
)
from superset.daos.exceptions import DAODeleteFailedError
from superset.daos.security import RLSDAO
from superset.reports.models import ReportSchedule
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__)
@ -34,13 +35,10 @@ class DeleteRLSRuleCommand(BaseCommand):
self._model_ids = model_ids
self._models: list[ReportSchedule] = []
@transaction(on_error=partial(on_error, reraise=RuleDeleteFailedError))
def run(self) -> None:
self.validate()
try:
RLSDAO.delete(self._models)
except DAODeleteFailedError as ex:
logger.exception(ex.exception)
raise RuleDeleteFailedError() from ex
RLSDAO.delete(self._models)
def validate(self) -> None:
# Validate/populate model exists

View File

@ -24,9 +24,9 @@ from superset.commands.exceptions import DatasourceNotFoundValidationError
from superset.commands.security.exceptions import RLSRuleNotFoundError
from superset.commands.utils import populate_roles
from superset.connectors.sqla.models import RowLevelSecurityFilter, SqlaTable
from superset.daos.exceptions import DAOUpdateFailedError
from superset.daos.security import RLSDAO
from superset.extensions import db
from superset.utils.decorators import transaction
logger = logging.getLogger(__name__)
@ -39,17 +39,11 @@ class UpdateRLSRuleCommand(BaseCommand):
self._roles = self._properties.get("roles", [])
self._model: Optional[RowLevelSecurityFilter] = None
@transaction()
def run(self) -> Any:
self.validate()
assert self._model
try:
rule = RLSDAO.update(self._model, self._properties)
except DAOUpdateFailedError as ex:
logger.exception(ex.exception)
raise
return rule
return RLSDAO.update(self._model, self._properties)
def validate(self) -> None:
self._model = RLSDAO.find_by_id(int(self._model_id))

View File

@ -22,10 +22,11 @@ import logging
from typing import Any, TYPE_CHECKING
from flask_babel import gettext as __
from sqlalchemy.exc import SQLAlchemyError
from superset import db
from superset.commands.base import BaseCommand
from superset.common.db_query_status import QueryStatus
from superset.daos.exceptions import DAOCreateFailedError
from superset.errors import SupersetErrorType
from superset.exceptions import (
SupersetErrorException,
@ -41,6 +42,7 @@ from superset.sqllab.exceptions import (
)
from superset.sqllab.execution_context_convertor import ExecutionContextConvertor
from superset.sqllab.limiting_factor import LimitingFactor
from superset.utils.decorators import transaction
if TYPE_CHECKING:
from superset.daos.database import DatabaseDAO
@ -90,6 +92,7 @@ class ExecuteSqlCommand(BaseCommand):
def validate(self) -> None:
pass
@transaction()
def run( # pylint: disable=too-many-statements,useless-suppression
self,
) -> CommandResult:
@ -178,9 +181,22 @@ class ExecuteSqlCommand(BaseCommand):
)
def _save_new_query(self, query: Query) -> None:
"""
Saves the new SQL Lab query.
Committing within a transaction violates the "unit of work" construct, but is
necessary for async querying. The Celery task is defined within the confines
of another command and needs to read a previously committed state given the
`READ COMMITTED` isolation level.
To mitigate said issue, ideally there would be a command to prepare said query
and another to execute it, either in a sync or async manner.
:param query: The SQL Lab query
"""
try:
self._query_dao.create(query)
except DAOCreateFailedError as ex:
except SQLAlchemyError as ex:
raise SqlLabException(
self._execution_context,
SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
@ -189,6 +205,8 @@ class ExecuteSqlCommand(BaseCommand):
"Please contact an administrator for further assistance or try again.",
) from ex
db.session.commit() # pylint: disable=consider-using-transaction
def _validate_access(self, query: Query) -> None:
try:
self._access_validator.validate(query)

View File

@ -15,16 +15,17 @@
# specific language governing permissions and limitations
# under the License.
import logging
from functools import partial
from typing import Any
from superset import db, security_manager
from superset import security_manager
from superset.commands.base import BaseCommand, CreateMixin
from superset.commands.tag.exceptions import TagCreateFailedError, TagInvalidError
from superset.commands.tag.utils import to_object_model, to_object_type
from superset.daos.exceptions import DAOCreateFailedError
from superset.daos.tag import TagDAO
from superset.exceptions import SupersetSecurityException
from superset.tags.models import ObjectType, TagType
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__)
@ -35,20 +36,18 @@ class CreateCustomTagCommand(CreateMixin, BaseCommand):
self._object_id = object_id
self._tags = tags
@transaction(on_error=partial(on_error, reraise=TagCreateFailedError))
def run(self) -> None:
self.validate()
try:
object_type = to_object_type(self._object_type)
if object_type is None:
raise TagCreateFailedError(f"invalid object type {self._object_type}")
TagDAO.create_custom_tagged_objects(
object_type=object_type,
object_id=self._object_id,
tag_names=self._tags,
)
except DAOCreateFailedError as ex:
logger.exception(ex.exception)
raise TagCreateFailedError() from ex
object_type = to_object_type(self._object_type)
if object_type is None:
raise TagCreateFailedError(f"invalid object type {self._object_type}")
TagDAO.create_custom_tagged_objects(
object_type=object_type,
object_id=self._object_id,
tag_names=self._tags,
)
def validate(self) -> None:
exceptions = []
@ -71,27 +70,20 @@ class CreateCustomTagWithRelationshipsCommand(CreateMixin, BaseCommand):
self._bulk_create = bulk_create
self._skipped_tagged_objects: set[tuple[str, int]] = set()
@transaction(on_error=partial(on_error, reraise=TagCreateFailedError))
def run(self) -> tuple[set[tuple[str, int]], set[tuple[str, int]]]:
self.validate()
try:
tag_name = self._properties["name"]
tag = TagDAO.get_by_name(tag_name.strip(), TagType.custom)
TagDAO.create_tag_relationship(
objects_to_tag=self._properties.get("objects_to_tag", []),
tag=tag,
bulk_create=self._bulk_create,
)
tag_name = self._properties["name"]
tag = TagDAO.get_by_name(tag_name.strip(), TagType.custom)
TagDAO.create_tag_relationship(
objects_to_tag=self._properties.get("objects_to_tag", []),
tag=tag,
bulk_create=self._bulk_create,
)
tag.description = self._properties.get("description", "")
db.session.commit()
return set(self._properties["objects_to_tag"]), self._skipped_tagged_objects
except DAOCreateFailedError as ex:
logger.exception(ex.exception)
raise TagCreateFailedError() from ex
tag.description = self._properties.get("description", "")
return set(self._properties["objects_to_tag"]), self._skipped_tagged_objects
def validate(self) -> None:
exceptions = []

View File

@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from functools import partial
from superset.commands.base import BaseCommand
from superset.commands.tag.exceptions import (
@ -25,9 +26,9 @@ from superset.commands.tag.exceptions import (
TagNotFoundError,
)
from superset.commands.tag.utils import to_object_type
from superset.daos.exceptions import DAODeleteFailedError
from superset.daos.tag import TagDAO
from superset.tags.models import ObjectType
from superset.utils.decorators import on_error, transaction
from superset.views.base import DeleteMixin
logger = logging.getLogger(__name__)
@ -39,18 +40,15 @@ class DeleteTaggedObjectCommand(DeleteMixin, BaseCommand):
self._object_id = object_id
self._tag = tag
@transaction(on_error=partial(on_error, reraise=TaggedObjectDeleteFailedError))
def run(self) -> None:
self.validate()
try:
object_type = to_object_type(self._object_type)
if object_type is None:
raise TaggedObjectDeleteFailedError(
f"invalid object type {self._object_type}"
)
TagDAO.delete_tagged_object(object_type, self._object_id, self._tag)
except DAODeleteFailedError as ex:
logger.exception(ex.exception)
raise TaggedObjectDeleteFailedError() from ex
object_type = to_object_type(self._object_type)
if object_type is None:
raise TaggedObjectDeleteFailedError(
f"invalid object type {self._object_type}"
)
TagDAO.delete_tagged_object(object_type, self._object_id, self._tag)
def validate(self) -> None:
exceptions = []
@ -92,13 +90,10 @@ class DeleteTagsCommand(DeleteMixin, BaseCommand):
def __init__(self, tags: list[str]):
self._tags = tags
@transaction(on_error=partial(on_error, reraise=TagDeleteFailedError))
def run(self) -> None:
self.validate()
try:
TagDAO.delete_tags(self._tags)
except DAODeleteFailedError as ex:
logger.exception(ex.exception)
raise TagDeleteFailedError() from ex
TagDAO.delete_tags(self._tags)
def validate(self) -> None:
exceptions = []

View File

@ -25,6 +25,7 @@ from superset.commands.tag.exceptions import TagInvalidError, TagNotFoundError
from superset.commands.tag.utils import to_object_type
from superset.daos.tag import TagDAO
from superset.tags.models import Tag
from superset.utils.decorators import transaction
logger = logging.getLogger(__name__)
@ -35,18 +36,17 @@ class UpdateTagCommand(UpdateMixin, BaseCommand):
self._properties = data.copy()
self._model: Optional[Tag] = None
@transaction()
def run(self) -> Model:
self.validate()
if self._model:
self._model.name = self._properties["name"]
TagDAO.create_tag_relationship(
objects_to_tag=self._properties.get("objects_to_tag", []),
tag=self._model,
)
self._model.description = self._properties.get("description")
db.session.add(self._model)
db.session.commit()
assert self._model
self._model.name = self._properties["name"]
TagDAO.create_tag_relationship(
objects_to_tag=self._properties.get("objects_to_tag", []),
tag=self._model,
)
self._model.description = self._properties.get("description")
db.session.add(self._model)
return self._model

View File

@ -16,12 +16,12 @@
# under the License.
import logging
from abc import ABC, abstractmethod
from sqlalchemy.exc import SQLAlchemyError
from functools import partial
from superset.commands.base import BaseCommand
from superset.commands.temporary_cache.exceptions import TemporaryCacheCreateFailedError
from superset.commands.temporary_cache.parameters import CommandParameters
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__)
@ -30,12 +30,9 @@ class CreateTemporaryCacheCommand(BaseCommand, ABC):
def __init__(self, cmd_params: CommandParameters):
self._cmd_params = cmd_params
@transaction(on_error=partial(on_error, reraise=TemporaryCacheCreateFailedError))
def run(self) -> str:
try:
return self.create(self._cmd_params)
except SQLAlchemyError as ex:
logger.exception("Error running create command")
raise TemporaryCacheCreateFailedError() from ex
return self.create(self._cmd_params)
def validate(self) -> None:
pass

View File

@ -16,12 +16,12 @@
# under the License.
import logging
from abc import ABC, abstractmethod
from sqlalchemy.exc import SQLAlchemyError
from functools import partial
from superset.commands.base import BaseCommand
from superset.commands.temporary_cache.exceptions import TemporaryCacheDeleteFailedError
from superset.commands.temporary_cache.parameters import CommandParameters
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__)
@ -30,12 +30,9 @@ class DeleteTemporaryCacheCommand(BaseCommand, ABC):
def __init__(self, cmd_params: CommandParameters):
self._cmd_params = cmd_params
@transaction(on_error=partial(on_error, reraise=TemporaryCacheDeleteFailedError))
def run(self) -> bool:
try:
return self.delete(self._cmd_params)
except SQLAlchemyError as ex:
logger.exception("Error running delete command")
raise TemporaryCacheDeleteFailedError() from ex
return self.delete(self._cmd_params)
def validate(self) -> None:
pass

View File

@ -16,13 +16,13 @@
# under the License.
import logging
from abc import ABC, abstractmethod
from functools import partial
from typing import Optional
from sqlalchemy.exc import SQLAlchemyError
from superset.commands.base import BaseCommand
from superset.commands.temporary_cache.exceptions import TemporaryCacheUpdateFailedError
from superset.commands.temporary_cache.parameters import CommandParameters
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__)
@ -34,12 +34,9 @@ class UpdateTemporaryCacheCommand(BaseCommand, ABC):
):
self._parameters = cmd_params
@transaction(on_error=partial(on_error, reraise=TemporaryCacheUpdateFailedError))
def run(self) -> Optional[str]:
try:
return self.update(self._parameters)
except SQLAlchemyError as ex:
logger.exception("Error running update command")
raise TemporaryCacheUpdateFailedError() from ex
return self.update(self._parameters)
def validate(self) -> None:
pass

View File

@ -1768,11 +1768,10 @@ class SqlaTable(
)
)
def fetch_metadata(self, commit: bool = True) -> MetadataResult:
def fetch_metadata(self) -> MetadataResult:
"""
Fetches the metadata for the table and merges it in
:param commit: should the changes be committed or not.
:return: Tuple with lists of added, removed and modified column names.
"""
new_columns = self.external_metadata()
@ -1850,8 +1849,6 @@ class SqlaTable(
config["SQLA_TABLE_MUTATOR"](self)
db.session.merge(self)
if commit:
db.session.commit()
return results
@classmethod

View File

@ -21,13 +21,8 @@ from typing import Any, Generic, get_args, TypeVar
from flask_appbuilder.models.filters import BaseFilter
from flask_appbuilder.models.sqla import Model
from flask_appbuilder.models.sqla.interface import SQLAInterface
from sqlalchemy.exc import SQLAlchemyError, StatementError
from sqlalchemy.exc import StatementError
from superset.daos.exceptions import (
DAOCreateFailedError,
DAODeleteFailedError,
DAOUpdateFailedError,
)
from superset.extensions import db
T = TypeVar("T", bound=Model)
@ -127,15 +122,12 @@ class BaseDAO(Generic[T]):
cls,
item: T | None = None,
attributes: dict[str, Any] | None = None,
commit: bool = True,
) -> T:
"""
Create an object from the specified item and/or attributes.
:param item: The object to create
:param attributes: The attributes associated with the object to create
:param commit: Whether to commit the transaction
:raises DAOCreateFailedError: If the creation failed
"""
if not item:
@ -145,15 +137,7 @@ class BaseDAO(Generic[T]):
for key, value in attributes.items():
setattr(item, key, value)
try:
db.session.add(item)
if commit:
db.session.commit()
except SQLAlchemyError as ex: # pragma: no cover
db.session.rollback()
raise DAOCreateFailedError(exception=ex) from ex
db.session.add(item)
return item # type: ignore
@classmethod
@ -161,15 +145,12 @@ class BaseDAO(Generic[T]):
cls,
item: T | None = None,
attributes: dict[str, Any] | None = None,
commit: bool = True,
) -> T:
"""
Update an object from the specified item and/or attributes.
:param item: The object to update
:param attributes: The attributes associated with the object to update
:param commit: Whether to commit the transaction
:raises DAOUpdateFailedError: If the updating failed
"""
if not item:
@ -179,19 +160,13 @@ class BaseDAO(Generic[T]):
for key, value in attributes.items():
setattr(item, key, value)
try:
db.session.merge(item)
if commit:
db.session.commit()
except SQLAlchemyError as ex: # pragma: no cover
db.session.rollback()
raise DAOUpdateFailedError(exception=ex) from ex
if item not in db.session:
return db.session.merge(item)
return item # type: ignore
@classmethod
def delete(cls, items: list[T], commit: bool = True) -> None:
def delete(cls, items: list[T]) -> None:
"""
Delete the specified items including their associated relationships.
@ -204,17 +179,8 @@ class BaseDAO(Generic[T]):
post-deletion logic.
:param items: The items to delete
:param commit: Whether to commit the transaction
:raises DAODeleteFailedError: If the deletion failed
:see: https://docs.sqlalchemy.org/en/latest/orm/queryguide/dml.html
"""
try:
for item in items:
db.session.delete(item)
if commit:
db.session.commit()
except SQLAlchemyError as ex:
db.session.rollback()
raise DAODeleteFailedError(exception=ex) from ex
for item in items:
db.session.delete(item)

View File

@ -62,7 +62,6 @@ class ChartDAO(BaseDAO[Slice]):
dttm=datetime.now(),
)
)
db.session.commit()
@staticmethod
def remove_favorite(chart: Slice) -> None:
@ -77,4 +76,3 @@ class ChartDAO(BaseDAO[Slice]):
)
if fav:
db.session.delete(fav)
db.session.commit()

View File

@ -179,8 +179,7 @@ class DashboardDAO(BaseDAO[Dashboard]):
dashboard: Dashboard,
data: dict[Any, Any],
old_to_new_slice_ids: dict[int, int] | None = None,
commit: bool = False,
) -> Dashboard:
) -> None:
new_filter_scopes = {}
md = dashboard.params_dict
@ -265,10 +264,6 @@ class DashboardDAO(BaseDAO[Dashboard]):
md["cross_filters_enabled"] = data.get("cross_filters_enabled", True)
dashboard.json_metadata = json.dumps(md)
if commit:
db.session.commit()
return dashboard
@staticmethod
def favorited_ids(dashboards: list[Dashboard]) -> list[FavStar]:
ids = [dash.id for dash in dashboards]
@ -321,7 +316,6 @@ class DashboardDAO(BaseDAO[Dashboard]):
dash.params = original_dash.params
cls.set_dash_metadata(dash, metadata, old_to_new_slice_ids)
db.session.add(dash)
db.session.commit()
return dash
@staticmethod
@ -336,7 +330,6 @@ class DashboardDAO(BaseDAO[Dashboard]):
dttm=datetime.now(),
)
)
db.session.commit()
@staticmethod
def remove_favorite(dashboard: Dashboard) -> None:
@ -351,7 +344,6 @@ class DashboardDAO(BaseDAO[Dashboard]):
)
if fav:
db.session.delete(fav)
db.session.commit()
class EmbeddedDashboardDAO(BaseDAO[EmbeddedDashboard]):
@ -369,7 +361,6 @@ class EmbeddedDashboardDAO(BaseDAO[EmbeddedDashboard]):
)
embedded.allow_domain_list = ",".join(allowed_domains)
dashboard.embedded = [embedded]
db.session.commit()
return embedded
@classmethod
@ -377,7 +368,6 @@ class EmbeddedDashboardDAO(BaseDAO[EmbeddedDashboard]):
cls,
item: EmbeddedDashboardDAO | None = None,
attributes: dict[str, Any] | None = None,
commit: bool = True,
) -> Any:
"""
Use EmbeddedDashboardDAO.upsert() instead.

View File

@ -42,7 +42,6 @@ class DatabaseDAO(BaseDAO[Database]):
cls,
item: Database | None = None,
attributes: dict[str, Any] | None = None,
commit: bool = True,
) -> Database:
"""
Unmask ``encrypted_extra`` before updating.
@ -60,7 +59,7 @@ class DatabaseDAO(BaseDAO[Database]):
attributes["encrypted_extra"],
)
return super().update(item, attributes, commit)
return super().update(item, attributes)
@staticmethod
def validate_uniqueness(database_name: str) -> bool:
@ -174,7 +173,6 @@ class SSHTunnelDAO(BaseDAO[SSHTunnel]):
cls,
item: SSHTunnel | None = None,
attributes: dict[str, Any] | None = None,
commit: bool = True,
) -> SSHTunnel:
"""
Unmask ``password``, ``private_key`` and ``private_key_password`` before updating.
@ -190,7 +188,7 @@ class SSHTunnelDAO(BaseDAO[SSHTunnel]):
attributes.pop("id", None)
attributes = unmask_password_info(attributes, item)
return super().update(item, attributes, commit)
return super().update(item, attributes)
class DatabaseUserOAuth2TokensDAO(BaseDAO[DatabaseUserOAuth2Tokens]):

View File

@ -25,7 +25,6 @@ from sqlalchemy.exc import SQLAlchemyError
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
from superset.daos.base import BaseDAO
from superset.daos.exceptions import DAOUpdateFailedError
from superset.extensions import db
from superset.models.core import Database
from superset.models.dashboard import Dashboard
@ -171,7 +170,6 @@ class DatasetDAO(BaseDAO[SqlaTable]):
cls,
item: SqlaTable | None = None,
attributes: dict[str, Any] | None = None,
commit: bool = True,
) -> SqlaTable:
"""
Updates a Dataset model on the metadata DB
@ -182,21 +180,19 @@ class DatasetDAO(BaseDAO[SqlaTable]):
cls.update_columns(
item,
attributes.pop("columns"),
commit=commit,
override_columns=bool(attributes.get("override_columns")),
)
if "metrics" in attributes:
cls.update_metrics(item, attributes.pop("metrics"), commit=commit)
cls.update_metrics(item, attributes.pop("metrics"))
return super().update(item, attributes, commit=commit)
return super().update(item, attributes)
@classmethod
def update_columns(
cls,
model: SqlaTable,
property_columns: list[dict[str, Any]],
commit: bool = True,
override_columns: bool = False,
) -> None:
"""
@ -217,7 +213,7 @@ class DatasetDAO(BaseDAO[SqlaTable]):
if not DatasetDAO.validate_python_date_format(
column["python_date_format"]
):
raise DAOUpdateFailedError(
raise ValueError(
"python_date_format is an invalid date/timestamp format."
)
@ -266,15 +262,11 @@ class DatasetDAO(BaseDAO[SqlaTable]):
)
).delete(synchronize_session="fetch")
if commit:
db.session.commit()
@classmethod
def update_metrics(
cls,
model: SqlaTable,
property_metrics: list[dict[str, Any]],
commit: bool = True,
) -> None:
"""
Creates/updates and/or deletes a list of metrics, based on a
@ -317,9 +309,6 @@ class DatasetDAO(BaseDAO[SqlaTable]):
)
).delete(synchronize_session="fetch")
if commit:
db.session.commit()
@classmethod
def find_dataset_column(cls, dataset_id: int, column_id: int) -> TableColumn | None:
# We want to apply base dataset filters

View File

@ -23,30 +23,6 @@ class DAOException(SupersetException):
"""
class DAOCreateFailedError(DAOException):
"""
DAO Create failed
"""
message = "Create failed"
class DAOUpdateFailedError(DAOException):
"""
DAO Update failed
"""
message = "Update failed"
class DAODeleteFailedError(DAOException):
"""
DAO Delete failed
"""
message = "Delete failed"
class DatasourceTypeNotSupportedError(DAOException):
"""
DAO datasource query source type is not supported

View File

@ -53,7 +53,6 @@ class QueryDAO(BaseDAO[Query]):
for saved_query in related_saved_queries:
saved_query.rows = query.rows
saved_query.last_run = datetime.now()
db.session.commit()
@staticmethod
def save_metadata(query: Query, payload: dict[str, Any]) -> None:
@ -97,7 +96,6 @@ class QueryDAO(BaseDAO[Query]):
query.status = QueryStatus.STOPPED
query.end_time = now_as_float()
db.session.commit()
class SavedQueryDAO(BaseDAO[SavedQuery]):

View File

@ -20,10 +20,7 @@ import logging
from datetime import datetime
from typing import Any
from sqlalchemy.exc import SQLAlchemyError
from superset.daos.base import BaseDAO
from superset.daos.exceptions import DAODeleteFailedError
from superset.extensions import db
from superset.reports.filters import ReportScheduleFilter
from superset.reports.models import (
@ -137,15 +134,12 @@ class ReportScheduleDAO(BaseDAO[ReportSchedule]):
cls,
item: ReportSchedule | None = None,
attributes: dict[str, Any] | None = None,
commit: bool = True,
) -> ReportSchedule:
"""
Create a report schedule with nested recipients.
:param item: The object to create
:param attributes: The attributes associated with the object to create
:param commit: Whether to commit the transaction
:raises: DAOCreateFailedError: If the creation failed
"""
# TODO(john-bodley): Determine why we need special handling for recipients.
@ -165,22 +159,19 @@ class ReportScheduleDAO(BaseDAO[ReportSchedule]):
for recipient in recipients
]
return super().create(item, attributes, commit)
return super().create(item, attributes)
@classmethod
def update(
cls,
item: ReportSchedule | None = None,
attributes: dict[str, Any] | None = None,
commit: bool = True,
) -> ReportSchedule:
"""
Update a report schedule with nested recipients.
:param item: The object to update
:param attributes: The attributes associated with the object to update
:param commit: Whether to commit the transaction
:raises: DAOUpdateFailedError: If the update failed
"""
# TODO(john-bodley): Determine why we need special handling for recipients.
@ -200,7 +191,7 @@ class ReportScheduleDAO(BaseDAO[ReportSchedule]):
for recipient in recipients
]
return super().update(item, attributes, commit)
return super().update(item, attributes)
@staticmethod
def find_active() -> list[ReportSchedule]:
@ -283,23 +274,12 @@ class ReportScheduleDAO(BaseDAO[ReportSchedule]):
return last_error_email_log if not report_from_last_email else None
@staticmethod
def bulk_delete_logs(
model: ReportSchedule,
from_date: datetime,
commit: bool = True,
) -> int | None:
try:
row_count = (
db.session.query(ReportExecutionLog)
.filter(
ReportExecutionLog.report_schedule == model,
ReportExecutionLog.end_dttm < from_date,
)
.delete(synchronize_session="fetch")
def bulk_delete_logs(model: ReportSchedule, from_date: datetime) -> int | None:
return (
db.session.query(ReportExecutionLog)
.filter(
ReportExecutionLog.report_schedule == model,
ReportExecutionLog.end_dttm < from_date,
)
if commit:
db.session.commit()
return row_count
except SQLAlchemyError as ex:
db.session.rollback()
raise DAODeleteFailedError(str(ex)) from ex
.delete(synchronize_session="fetch")
)

View File

@ -19,12 +19,11 @@ from operator import and_
from typing import Any, Optional
from flask import g
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.exc import NoResultFound
from superset.commands.tag.exceptions import TagNotFoundError
from superset.commands.tag.utils import to_object_type
from superset.daos.base import BaseDAO
from superset.daos.exceptions import DAODeleteFailedError
from superset.exceptions import MissingUserContextException
from superset.extensions import db
from superset.models.dashboard import Dashboard
@ -75,7 +74,6 @@ class TagDAO(BaseDAO[Tag]):
)
db.session.add_all(tagged_objects)
db.session.commit()
@staticmethod
def delete_tagged_object(
@ -86,9 +84,7 @@ class TagDAO(BaseDAO[Tag]):
"""
tag = TagDAO.find_by_name(tag_name.strip())
if not tag:
raise DAODeleteFailedError(
message=f"Tag with name {tag_name} does not exist."
)
raise NoResultFound(message=f"Tag with name {tag_name} does not exist.")
tagged_object = db.session.query(TaggedObject).filter(
TaggedObject.tag_id == tag.id,
@ -96,17 +92,13 @@ class TagDAO(BaseDAO[Tag]):
TaggedObject.object_id == object_id,
)
if not tagged_object:
raise DAODeleteFailedError(
raise NoResultFound(
message=f'Tagged object with object_id: {object_id} \
object_type: {object_type} \
and tag name: "{tag_name}" could not be found'
)
try:
db.session.delete(tagged_object.one())
db.session.commit()
except SQLAlchemyError as ex: # pragma: no cover
db.session.rollback()
raise DAODeleteFailedError(exception=ex) from ex
db.session.delete(tagged_object.one())
@staticmethod
def delete_tags(tag_names: list[str]) -> None:
@ -117,18 +109,12 @@ class TagDAO(BaseDAO[Tag]):
for name in tag_names:
tag_name = name.strip()
if not TagDAO.find_by_name(tag_name):
raise DAODeleteFailedError(
message=f"Tag with name {tag_name} does not exist."
)
raise NoResultFound(message=f"Tag with name {tag_name} does not exist.")
tags_to_delete.append(tag_name)
tag_objects = db.session.query(Tag).filter(Tag.name.in_(tags_to_delete))
for tag in tag_objects:
try:
db.session.delete(tag)
db.session.commit()
except SQLAlchemyError as ex: # pragma: no cover
db.session.rollback()
raise DAODeleteFailedError(exception=ex) from ex
db.session.delete(tag)
@staticmethod
def get_by_name(name: str, type_: TagType = TagType.custom) -> Tag:
@ -283,21 +269,10 @@ class TagDAO(BaseDAO[Tag]):
) -> None:
"""
Marks a specific tag as a favorite for the current user.
This function will find the tag by the provided id,
create a new UserFavoriteTag object that represents
the user's preference, add that object to the database
session, and commit the session. It uses the currently
authenticated user from the global 'g' object.
Args:
tag_id: The id of the tag that is to be marked as
favorite.
Raises:
Any exceptions raised by the find_by_id function,
the UserFavoriteTag constructor, or the database session's
add and commit methods will propagate up to the caller.
Returns:
None.
:param tag_id: The id of the tag that is to be marked as favorite
"""
tag = TagDAO.find_by_id(tag_id)
user = g.user
@ -307,26 +282,13 @@ class TagDAO(BaseDAO[Tag]):
raise TagNotFoundError()
tag.users_favorited.append(user)
db.session.commit()
@staticmethod
def remove_user_favorite_tag(tag_id: int) -> None:
"""
Removes a tag from the current user's favorite tags.
This function will find the tag by the provided id and remove the tag
from the user's list of favorite tags. It uses the currently authenticated
user from the global 'g' object.
Args:
tag_id: The id of the tag that is to be removed from the favorite tags.
Raises:
Any exceptions raised by the find_by_id function, the database session's
commit method will propagate up to the caller.
Returns:
None.
:param tag_id: The id of the tag that is to be removed from the favorite tags
"""
tag = TagDAO.find_by_id(tag_id)
user = g.user
@ -338,9 +300,6 @@ class TagDAO(BaseDAO[Tag]):
tag.users_favorited.remove(user)
# Commit to save the changes
db.session.commit()
@staticmethod
def favorited_ids(tags: list[Tag]) -> list[int]:
"""
@ -424,5 +383,4 @@ class TagDAO(BaseDAO[Tag]):
object_id,
tag.name,
)
db.session.add_all(tagged_objects)

View File

@ -40,4 +40,3 @@ class UserDAO(BaseDAO[User]):
attrs = UserAttribute(avatar_url=url, user_id=user.id)
user.extra_attributes = [attrs]
db.session.add(attrs)
db.session.commit()

View File

@ -32,7 +32,7 @@ from marshmallow import ValidationError
from werkzeug.wrappers import Response as WerkzeugResponse
from werkzeug.wsgi import FileWrapper
from superset import is_feature_enabled, thumbnail_cache
from superset import db, is_feature_enabled, thumbnail_cache
from superset.charts.schemas import ChartEntityResponseSchema
from superset.commands.dashboard.create import CreateDashboardCommand
from superset.commands.dashboard.delete import DeleteDashboardCommand
@ -1314,7 +1314,13 @@ class DashboardRestApi(BaseSupersetModelRestApi):
"""
try:
body = self.embedded_config_schema.load(request.json)
embedded = EmbeddedDashboardDAO.upsert(dashboard, body["allowed_domains"])
with db.session.begin_nested():
embedded = EmbeddedDashboardDAO.upsert(
dashboard,
body["allowed_domains"],
)
result = self.embedded_response_schema.dump(embedded)
return self.response(200, result=result)
except ValidationError as error:

View File

@ -1410,7 +1410,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
database_id=state["database_id"],
)
if existing:
DatabaseUserOAuth2TokensDAO.delete([existing], commit=True)
DatabaseUserOAuth2TokensDAO.delete([existing])
# store tokens
expiration = datetime.now() + timedelta(seconds=token_response["expires_in"])
@ -1422,7 +1422,6 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
"access_token_expiration": expiration,
"refresh_token": token_response.get("refresh_token"),
},
commit=True,
)
# return blank page that closes itself

View File

@ -455,4 +455,4 @@ class GSheetsEngineSpec(ShillelaghEngineSpec):
catalog[table.table] = spreadsheet_url
database.extra = json.dumps(extra)
db.session.add(database)
db.session.commit()
db.session.commit() # pylint: disable=consider-using-transaction

View File

@ -408,7 +408,7 @@ class HiveEngineSpec(PrestoEngineSpec):
logger.info("Query %s: [%s] %s", str(query_id), str(job_id), l)
last_log_line = len(log_lines)
if needs_commit:
db.session.commit()
db.session.commit() # pylint: disable=consider-using-transaction
if sleep_interval := current_app.config.get("HIVE_POLL_INTERVAL"):
logger.warning(
"HIVE_POLL_INTERVAL is deprecated and will be removed in 3.0. Please use DB_POLL_INTERVAL_SECONDS instead"

View File

@ -151,7 +151,7 @@ class ImpalaEngineSpec(BaseEngineSpec):
needs_commit = True
if needs_commit:
db.session.commit()
db.session.commit() # pylint: disable=consider-using-transaction
sleep_interval = current_app.config["DB_POLL_INTERVAL_SECONDS"].get(
cls.engine, 5
)

View File

@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=too-many-lines
# pylint: disable=consider-using-transaction,too-many-lines
from __future__ import annotations
import contextlib

View File

@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=consider-using-transaction
from __future__ import annotations
import contextlib

View File

@ -65,5 +65,4 @@ def load_bart_lines(only_metadata: bool = False, force: bool = False) -> None:
tbl.description = "BART lines"
tbl.database = database
tbl.filter_select_enabled = True
db.session.commit()
tbl.fetch_metadata()

View File

@ -111,8 +111,6 @@ def load_birth_names(
_set_table_metadata(obj, database)
_add_table_metrics(obj)
db.session.commit()
slices, _ = create_slices(obj)
create_dashboard(slices)
@ -844,5 +842,4 @@ def create_dashboard(slices: list[Slice]) -> Dashboard:
dash.dashboard_title = "USA Births Names"
dash.position_json = json.dumps(pos, indent=4)
dash.slug = "births"
db.session.commit()
return dash

View File

@ -88,7 +88,6 @@ def load_country_map_data(only_metadata: bool = False, force: bool = False) -> N
if not any(col.metric_name == "avg__2004" for col in obj.metrics):
col = str(column("2004").compile(db.engine))
obj.metrics.append(SqlMetric(metric_name="avg__2004", expression=f"AVG({col})"))
db.session.commit()
obj.fetch_metadata()
tbl = obj

View File

@ -52,7 +52,6 @@ def load_css_templates() -> None:
"""
)
obj.css = css
db.session.commit()
obj = db.session.query(CssTemplate).filter_by(template_name="Courier Black").first()
if not obj:
@ -97,4 +96,3 @@ def load_css_templates() -> None:
"""
)
obj.css = css
db.session.commit()

View File

@ -541,4 +541,3 @@ def load_deck_dash() -> None: # pylint: disable=too-many-statements
dash.dashboard_title = title
dash.slug = slug
dash.slices = slices
db.session.commit()

View File

@ -14,8 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Loads datasets, dashboards and slices in a new superset instance"""
import textwrap
import pandas as pd
@ -79,7 +77,6 @@ def load_energy(
SqlMetric(metric_name="sum__value", expression=f"SUM({col})")
)
db.session.commit()
tbl.fetch_metadata()
slc = Slice(

View File

@ -66,6 +66,5 @@ def load_flights(only_metadata: bool = False, force: bool = False) -> None:
tbl.description = "Random set of flights in the US"
tbl.database = database
tbl.filter_select_enabled = True
db.session.commit()
tbl.fetch_metadata()
print("Done loading table!")

View File

@ -14,8 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Loads datasets, dashboards and slices in a new superset instance"""
import os
from typing import Any
@ -62,7 +60,6 @@ def merge_slice(slc: Slice) -> None:
if o:
db.session.delete(o)
db.session.add(slc)
db.session.commit()
def get_slice_json(defaults: dict[Any, Any], **kwargs: Any) -> str:

View File

@ -97,7 +97,6 @@ def load_long_lat_data(only_metadata: bool = False, force: bool = False) -> None
obj.main_dttm_col = "datetime"
obj.database = database
obj.filter_select_enabled = True
db.session.commit()
obj.fetch_metadata()
tbl = obj

View File

@ -140,4 +140,3 @@ def load_misc_dashboard() -> None:
dash.position_json = json.dumps(pos, indent=4)
dash.slug = DASH_SLUG
dash.slices = slices
db.session.commit()

View File

@ -102,7 +102,6 @@ def load_multiformat_time_series( # pylint: disable=too-many-locals
col.python_date_format = dttm_and_expr[0]
col.database_expression = dttm_and_expr[1]
col.is_dttm = True
db.session.commit()
obj.fetch_metadata()
tbl = obj

View File

@ -62,5 +62,4 @@ def load_paris_iris_geojson(only_metadata: bool = False, force: bool = False) ->
tbl.description = "Map of Paris"
tbl.database = database
tbl.filter_select_enabled = True
db.session.commit()
tbl.fetch_metadata()

View File

@ -14,7 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pandas as pd
from sqlalchemy import DateTime, inspect, String
@ -72,7 +71,6 @@ def load_random_time_series_data(
obj.main_dttm_col = "ds"
obj.database = database
obj.filter_select_enabled = True
db.session.commit()
obj.fetch_metadata()
tbl = obj

View File

@ -64,5 +64,4 @@ def load_sf_population_polygons(
tbl.description = "Population density of San Francisco"
tbl.database = database
tbl.filter_select_enabled = True
db.session.commit()
tbl.fetch_metadata()

View File

@ -14,9 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=too-many-lines
import textwrap
from sqlalchemy import inspect
@ -1274,4 +1272,3 @@ def load_supported_charts_dashboard() -> None:
dash.dashboard_title = "Supported Charts Dashboard"
dash.position_json = json.dumps(pos, indent=2)
dash.slug = DASH_SLUG
db.session.commit()

View File

@ -14,8 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Loads datasets, dashboards and slices in a new superset instance"""
import textwrap
from superset import db
@ -558,4 +556,3 @@ def load_tabbed_dashboard(_: bool = False) -> None:
dash.slices = slices
dash.dashboard_title = "Tabbed Dashboard"
dash.slug = slug
db.session.commit()

Some files were not shown because too many files have changed in this diff Show More