chore(dao/command): Add transaction decorator to try to enforce "unit of work" (#24969)

This commit is contained in:
John Bodley 2024-06-28 12:33:56 -07:00 committed by GitHub
parent a3f0d00714
commit 8fb8199a55
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
151 changed files with 681 additions and 916 deletions

View File

@ -240,6 +240,7 @@ ignore_basepython_conflict = true
commands = commands =
superset db upgrade superset db upgrade
superset init superset init
superset load-test-users
# use -s to be able to use break pointers. # use -s to be able to use break pointers.
# no args or tests/* can be passed as an argument to run all tests # no args or tests/* can be passed as an argument to run all tests
pytest -s {posargs} pytest -s {posargs}

View File

@ -17,8 +17,10 @@
from collections import defaultdict from collections import defaultdict
from superset import security_manager from superset import security_manager
from superset.utils.decorators import transaction
@transaction()
def cleanup_permissions() -> None: def cleanup_permissions() -> None:
# 1. Clean up duplicates. # 1. Clean up duplicates.
pvms = security_manager.get_session.query( pvms = security_manager.get_session.query(
@ -29,7 +31,6 @@ def cleanup_permissions() -> None:
for pvm in pvms: for pvm in pvms:
pvms_dict[(pvm.permission, pvm.view_menu)].append(pvm) pvms_dict[(pvm.permission, pvm.view_menu)].append(pvm)
duplicates = [v for v in pvms_dict.values() if len(v) > 1] duplicates = [v for v in pvms_dict.values() if len(v) > 1]
len(duplicates)
for pvm_list in duplicates: for pvm_list in duplicates:
first_prm = pvm_list[0] first_prm = pvm_list[0]
@ -38,7 +39,6 @@ def cleanup_permissions() -> None:
roles = roles.union(pvm.role) roles = roles.union(pvm.role)
security_manager.get_session.delete(pvm) security_manager.get_session.delete(pvm)
first_prm.roles = list(roles) first_prm.roles = list(roles)
security_manager.get_session.commit()
pvms = security_manager.get_session.query( pvms = security_manager.get_session.query(
security_manager.permissionview_model security_manager.permissionview_model
@ -52,7 +52,6 @@ def cleanup_permissions() -> None:
for pvm in pvms: for pvm in pvms:
if not (pvm.view_menu and pvm.permission): if not (pvm.view_menu and pvm.permission):
security_manager.get_session.delete(pvm) security_manager.get_session.delete(pvm)
security_manager.get_session.commit()
pvms = security_manager.get_session.query( pvms = security_manager.get_session.query(
security_manager.permissionview_model security_manager.permissionview_model
@ -63,7 +62,6 @@ def cleanup_permissions() -> None:
roles = security_manager.get_session.query(security_manager.role_model).all() roles = security_manager.get_session.query(security_manager.role_model).all()
for role in roles: for role in roles:
role.permissions = [p for p in role.permissions if p] role.permissions = [p for p in role.permissions if p]
security_manager.get_session.commit()
# 4. Delete empty roles from permission view menus # 4. Delete empty roles from permission view menus
pvms = security_manager.get_session.query( pvms = security_manager.get_session.query(
@ -71,7 +69,6 @@ def cleanup_permissions() -> None:
).all() ).all()
for pvm in pvms: for pvm in pvms:
pvm.role = [r for r in pvm.role if r] pvm.role = [r for r in pvm.role if r]
security_manager.get_session.commit()
cleanup_permissions() cleanup_permissions()

View File

@ -29,6 +29,7 @@ echo "Superset config module: $SUPERSET_CONFIG"
superset db upgrade superset db upgrade
superset init superset init
superset load-test-users
echo "Running tests" echo "Running tests"

View File

@ -113,8 +113,10 @@ class CacheRestApi(BaseSupersetModelRestApi):
delete_stmt = CacheKey.__table__.delete().where( # pylint: disable=no-member delete_stmt = CacheKey.__table__.delete().where( # pylint: disable=no-member
CacheKey.cache_key.in_(cache_keys) CacheKey.cache_key.in_(cache_keys)
) )
db.session.execute(delete_stmt)
db.session.commit() with db.session.begin_nested():
db.session.execute(delete_stmt)
stats_logger_manager.instance.gauge( stats_logger_manager.instance.gauge(
"invalidated_cache", len(cache_keys) "invalidated_cache", len(cache_keys)
) )
@ -125,7 +127,5 @@ class CacheRestApi(BaseSupersetModelRestApi):
) )
except SQLAlchemyError as ex: # pragma: no cover except SQLAlchemyError as ex: # pragma: no cover
logger.error(ex, exc_info=True) logger.error(ex, exc_info=True)
db.session.rollback()
return self.response_500(str(ex)) return self.response_500(str(ex))
db.session.commit()
return self.response(201) return self.response(201)

View File

@ -20,6 +20,7 @@ import click
from flask.cli import with_appcontext from flask.cli import with_appcontext
import superset.utils.database as database_utils import superset.utils.database as database_utils
from superset.utils.decorators import transaction
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -89,6 +90,7 @@ def load_examples_run(
@click.command() @click.command()
@with_appcontext @with_appcontext
@transaction()
@click.option("--load-test-data", "-t", is_flag=True, help="Load additional test data") @click.option("--load-test-data", "-t", is_flag=True, help="Load additional test data")
@click.option("--load-big-data", "-b", is_flag=True, help="Load additional big data") @click.option("--load-big-data", "-b", is_flag=True, help="Load additional big data")
@click.option( @click.option(

View File

@ -27,6 +27,7 @@ from flask.cli import FlaskGroup, with_appcontext
from superset import app, appbuilder, cli, security_manager from superset import app, appbuilder, cli, security_manager
from superset.cli.lib import normalize_token from superset.cli.lib import normalize_token
from superset.extensions import db from superset.extensions import db
from superset.utils.decorators import transaction
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -60,6 +61,7 @@ for load, module_name, is_pkg in pkgutil.walk_packages(
@superset.command() @superset.command()
@with_appcontext @with_appcontext
@transaction()
def init() -> None: def init() -> None:
"""Inits the Superset application""" """Inits the Superset application"""
appbuilder.add_permissions(update_perms=True) appbuilder.add_permissions(update_perms=True)

View File

@ -22,12 +22,14 @@ from flask.cli import with_appcontext
import superset.utils.database as database_utils import superset.utils.database as database_utils
from superset import app, security_manager from superset import app, security_manager
from superset.utils.decorators import transaction
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@click.command() @click.command()
@with_appcontext @with_appcontext
@transaction()
def load_test_users() -> None: def load_test_users() -> None:
""" """
Loads admin, alpha, and gamma user for testing purposes Loads admin, alpha, and gamma user for testing purposes
@ -35,15 +37,7 @@ def load_test_users() -> None:
Syncs permissions for those users/roles Syncs permissions for those users/roles
""" """
print(Fore.GREEN + "Loading a set of users for unit tests") print(Fore.GREEN + "Loading a set of users for unit tests")
load_test_users_run()
def load_test_users_run() -> None:
"""
Loads admin, alpha, and gamma user for testing purposes
Syncs permissions for those users/roles
"""
if app.config["TESTING"]: if app.config["TESTING"]:
sm = security_manager sm = security_manager
@ -84,4 +78,3 @@ def load_test_users_run() -> None:
sm.find_role(role), sm.find_role(role),
password="general", password="general",
) )
sm.get_session.commit()

View File

@ -30,6 +30,7 @@ from flask_appbuilder.api import BaseApi
from flask_appbuilder.api.manager import resolver from flask_appbuilder.api.manager import resolver
import superset.utils.database as database_utils import superset.utils.database as database_utils
from superset.utils.decorators import transaction
from superset.utils.encrypt import SecretsMigrator from superset.utils.encrypt import SecretsMigrator
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -37,6 +38,7 @@ logger = logging.getLogger(__name__)
@click.command() @click.command()
@with_appcontext @with_appcontext
@transaction()
@click.option("--database_name", "-d", help="Database name to change") @click.option("--database_name", "-d", help="Database name to change")
@click.option("--uri", "-u", help="Database URI to change") @click.option("--uri", "-u", help="Database URI to change")
@click.option( @click.option(
@ -53,6 +55,7 @@ def set_database_uri(database_name: str, uri: str, skip_create: bool) -> None:
@click.command() @click.command()
@with_appcontext @with_appcontext
@transaction()
def sync_tags() -> None: def sync_tags() -> None:
"""Rebuilds special tags (owner, type, favorited by).""" """Rebuilds special tags (owner, type, favorited by)."""
# pylint: disable=no-member # pylint: disable=no-member

View File

@ -16,6 +16,7 @@
# under the License. # under the License.
import logging import logging
from datetime import datetime from datetime import datetime
from functools import partial
from typing import Any, Optional from typing import Any, Optional
from flask_appbuilder.models.sqla import Model from flask_appbuilder.models.sqla import Model
@ -30,7 +31,7 @@ from superset.commands.annotation_layer.annotation.exceptions import (
from superset.commands.annotation_layer.exceptions import AnnotationLayerNotFoundError from superset.commands.annotation_layer.exceptions import AnnotationLayerNotFoundError
from superset.commands.base import BaseCommand from superset.commands.base import BaseCommand
from superset.daos.annotation_layer import AnnotationDAO, AnnotationLayerDAO from superset.daos.annotation_layer import AnnotationDAO, AnnotationLayerDAO
from superset.daos.exceptions import DAOCreateFailedError from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -39,13 +40,10 @@ class CreateAnnotationCommand(BaseCommand):
def __init__(self, data: dict[str, Any]): def __init__(self, data: dict[str, Any]):
self._properties = data.copy() self._properties = data.copy()
@transaction(on_error=partial(on_error, reraise=AnnotationCreateFailedError))
def run(self) -> Model: def run(self) -> Model:
self.validate() self.validate()
try: return AnnotationDAO.create(attributes=self._properties)
return AnnotationDAO.create(attributes=self._properties)
except DAOCreateFailedError as ex:
logger.exception(ex.exception)
raise AnnotationCreateFailedError() from ex
def validate(self) -> None: def validate(self) -> None:
exceptions: list[ValidationError] = [] exceptions: list[ValidationError] = []

View File

@ -15,6 +15,7 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import logging import logging
from functools import partial
from typing import Optional from typing import Optional
from superset.commands.annotation_layer.annotation.exceptions import ( from superset.commands.annotation_layer.annotation.exceptions import (
@ -23,8 +24,8 @@ from superset.commands.annotation_layer.annotation.exceptions import (
) )
from superset.commands.base import BaseCommand from superset.commands.base import BaseCommand
from superset.daos.annotation_layer import AnnotationDAO from superset.daos.annotation_layer import AnnotationDAO
from superset.daos.exceptions import DAODeleteFailedError
from superset.models.annotations import Annotation from superset.models.annotations import Annotation
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -34,15 +35,11 @@ class DeleteAnnotationCommand(BaseCommand):
self._model_ids = model_ids self._model_ids = model_ids
self._models: Optional[list[Annotation]] = None self._models: Optional[list[Annotation]] = None
@transaction(on_error=partial(on_error, reraise=AnnotationDeleteFailedError))
def run(self) -> None: def run(self) -> None:
self.validate() self.validate()
assert self._models assert self._models
AnnotationDAO.delete(self._models)
try:
AnnotationDAO.delete(self._models)
except DAODeleteFailedError as ex:
logger.exception(ex.exception)
raise AnnotationDeleteFailedError() from ex
def validate(self) -> None: def validate(self) -> None:
# Validate/populate model exists # Validate/populate model exists

View File

@ -16,6 +16,7 @@
# under the License. # under the License.
import logging import logging
from datetime import datetime from datetime import datetime
from functools import partial
from typing import Any, Optional from typing import Any, Optional
from flask_appbuilder.models.sqla import Model from flask_appbuilder.models.sqla import Model
@ -31,8 +32,8 @@ from superset.commands.annotation_layer.annotation.exceptions import (
from superset.commands.annotation_layer.exceptions import AnnotationLayerNotFoundError from superset.commands.annotation_layer.exceptions import AnnotationLayerNotFoundError
from superset.commands.base import BaseCommand from superset.commands.base import BaseCommand
from superset.daos.annotation_layer import AnnotationDAO, AnnotationLayerDAO from superset.daos.annotation_layer import AnnotationDAO, AnnotationLayerDAO
from superset.daos.exceptions import DAOUpdateFailedError
from superset.models.annotations import Annotation from superset.models.annotations import Annotation
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -43,16 +44,11 @@ class UpdateAnnotationCommand(BaseCommand):
self._properties = data.copy() self._properties = data.copy()
self._model: Optional[Annotation] = None self._model: Optional[Annotation] = None
@transaction(on_error=partial(on_error, reraise=AnnotationUpdateFailedError))
def run(self) -> Model: def run(self) -> Model:
self.validate() self.validate()
assert self._model assert self._model
return AnnotationDAO.update(self._model, self._properties)
try:
annotation = AnnotationDAO.update(self._model, self._properties)
except DAOUpdateFailedError as ex:
logger.exception(ex.exception)
raise AnnotationUpdateFailedError() from ex
return annotation
def validate(self) -> None: def validate(self) -> None:
exceptions: list[ValidationError] = [] exceptions: list[ValidationError] = []

View File

@ -15,6 +15,7 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import logging import logging
from functools import partial
from typing import Any from typing import Any
from flask_appbuilder.models.sqla import Model from flask_appbuilder.models.sqla import Model
@ -27,7 +28,7 @@ from superset.commands.annotation_layer.exceptions import (
) )
from superset.commands.base import BaseCommand from superset.commands.base import BaseCommand
from superset.daos.annotation_layer import AnnotationLayerDAO from superset.daos.annotation_layer import AnnotationLayerDAO
from superset.daos.exceptions import DAOCreateFailedError from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -36,13 +37,10 @@ class CreateAnnotationLayerCommand(BaseCommand):
def __init__(self, data: dict[str, Any]): def __init__(self, data: dict[str, Any]):
self._properties = data.copy() self._properties = data.copy()
@transaction(on_error=partial(on_error, reraise=AnnotationLayerCreateFailedError))
def run(self) -> Model: def run(self) -> Model:
self.validate() self.validate()
try: return AnnotationLayerDAO.create(attributes=self._properties)
return AnnotationLayerDAO.create(attributes=self._properties)
except DAOCreateFailedError as ex:
logger.exception(ex.exception)
raise AnnotationLayerCreateFailedError() from ex
def validate(self) -> None: def validate(self) -> None:
exceptions: list[ValidationError] = [] exceptions: list[ValidationError] = []

View File

@ -15,6 +15,7 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import logging import logging
from functools import partial
from typing import Optional from typing import Optional
from superset.commands.annotation_layer.exceptions import ( from superset.commands.annotation_layer.exceptions import (
@ -24,8 +25,8 @@ from superset.commands.annotation_layer.exceptions import (
) )
from superset.commands.base import BaseCommand from superset.commands.base import BaseCommand
from superset.daos.annotation_layer import AnnotationLayerDAO from superset.daos.annotation_layer import AnnotationLayerDAO
from superset.daos.exceptions import DAODeleteFailedError
from superset.models.annotations import AnnotationLayer from superset.models.annotations import AnnotationLayer
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -35,15 +36,11 @@ class DeleteAnnotationLayerCommand(BaseCommand):
self._model_ids = model_ids self._model_ids = model_ids
self._models: Optional[list[AnnotationLayer]] = None self._models: Optional[list[AnnotationLayer]] = None
@transaction(on_error=partial(on_error, reraise=AnnotationLayerDeleteFailedError))
def run(self) -> None: def run(self) -> None:
self.validate() self.validate()
assert self._models assert self._models
AnnotationLayerDAO.delete(self._models)
try:
AnnotationLayerDAO.delete(self._models)
except DAODeleteFailedError as ex:
logger.exception(ex.exception)
raise AnnotationLayerDeleteFailedError() from ex
def validate(self) -> None: def validate(self) -> None:
# Validate/populate model exists # Validate/populate model exists

View File

@ -15,6 +15,7 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import logging import logging
from functools import partial
from typing import Any, Optional from typing import Any, Optional
from flask_appbuilder.models.sqla import Model from flask_appbuilder.models.sqla import Model
@ -28,8 +29,8 @@ from superset.commands.annotation_layer.exceptions import (
) )
from superset.commands.base import BaseCommand from superset.commands.base import BaseCommand
from superset.daos.annotation_layer import AnnotationLayerDAO from superset.daos.annotation_layer import AnnotationLayerDAO
from superset.daos.exceptions import DAOUpdateFailedError
from superset.models.annotations import AnnotationLayer from superset.models.annotations import AnnotationLayer
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -40,16 +41,11 @@ class UpdateAnnotationLayerCommand(BaseCommand):
self._properties = data.copy() self._properties = data.copy()
self._model: Optional[AnnotationLayer] = None self._model: Optional[AnnotationLayer] = None
@transaction(on_error=partial(on_error, reraise=AnnotationLayerUpdateFailedError))
def run(self) -> Model: def run(self) -> Model:
self.validate() self.validate()
assert self._model assert self._model
return AnnotationLayerDAO.update(self._model, self._properties)
try:
annotation_layer = AnnotationLayerDAO.update(self._model, self._properties)
except DAOUpdateFailedError as ex:
logger.exception(ex.exception)
raise AnnotationLayerUpdateFailedError() from ex
return annotation_layer
def validate(self) -> None: def validate(self) -> None:
exceptions: list[ValidationError] = [] exceptions: list[ValidationError] = []

View File

@ -16,6 +16,7 @@
# under the License. # under the License.
import logging import logging
from datetime import datetime from datetime import datetime
from functools import partial
from typing import Any, Optional from typing import Any, Optional
from flask import g from flask import g
@ -33,7 +34,7 @@ from superset.commands.chart.exceptions import (
from superset.commands.utils import get_datasource_by_id from superset.commands.utils import get_datasource_by_id
from superset.daos.chart import ChartDAO from superset.daos.chart import ChartDAO
from superset.daos.dashboard import DashboardDAO from superset.daos.dashboard import DashboardDAO
from superset.daos.exceptions import DAOCreateFailedError from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -42,15 +43,12 @@ class CreateChartCommand(CreateMixin, BaseCommand):
def __init__(self, data: dict[str, Any]): def __init__(self, data: dict[str, Any]):
self._properties = data.copy() self._properties = data.copy()
@transaction(on_error=partial(on_error, reraise=ChartCreateFailedError))
def run(self) -> Model: def run(self) -> Model:
self.validate() self.validate()
try: self._properties["last_saved_at"] = datetime.now()
self._properties["last_saved_at"] = datetime.now() self._properties["last_saved_by"] = g.user
self._properties["last_saved_by"] = g.user return ChartDAO.create(attributes=self._properties)
return ChartDAO.create(attributes=self._properties)
except DAOCreateFailedError as ex:
logger.exception(ex.exception)
raise ChartCreateFailedError() from ex
def validate(self) -> None: def validate(self) -> None:
exceptions = [] exceptions = []

View File

@ -15,6 +15,7 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import logging import logging
from functools import partial
from typing import Optional from typing import Optional
from flask_babel import lazy_gettext as _ from flask_babel import lazy_gettext as _
@ -28,10 +29,10 @@ from superset.commands.chart.exceptions import (
ChartNotFoundError, ChartNotFoundError,
) )
from superset.daos.chart import ChartDAO from superset.daos.chart import ChartDAO
from superset.daos.exceptions import DAODeleteFailedError
from superset.daos.report import ReportScheduleDAO from superset.daos.report import ReportScheduleDAO
from superset.exceptions import SupersetSecurityException from superset.exceptions import SupersetSecurityException
from superset.models.slice import Slice from superset.models.slice import Slice
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -41,15 +42,11 @@ class DeleteChartCommand(BaseCommand):
self._model_ids = model_ids self._model_ids = model_ids
self._models: Optional[list[Slice]] = None self._models: Optional[list[Slice]] = None
@transaction(on_error=partial(on_error, reraise=ChartDeleteFailedError))
def run(self) -> None: def run(self) -> None:
self.validate() self.validate()
assert self._models assert self._models
ChartDAO.delete(self._models)
try:
ChartDAO.delete(self._models)
except DAODeleteFailedError as ex:
logger.exception(ex.exception)
raise ChartDeleteFailedError() from ex
def validate(self) -> None: def validate(self) -> None:
# Validate/populate model exists # Validate/populate model exists

View File

@ -77,7 +77,7 @@ def import_chart(
if chart.id is None: if chart.id is None:
db.session.flush() db.session.flush()
if user := get_user(): if (user := get_user()) and user not in chart.owners:
chart.owners.append(user) chart.owners.append(user)
return chart return chart

View File

@ -16,6 +16,7 @@
# under the License. # under the License.
import logging import logging
from datetime import datetime from datetime import datetime
from functools import partial
from typing import Any, Optional from typing import Any, Optional
from flask import g from flask import g
@ -35,10 +36,10 @@ from superset.commands.chart.exceptions import (
from superset.commands.utils import get_datasource_by_id, update_tags, validate_tags from superset.commands.utils import get_datasource_by_id, update_tags, validate_tags
from superset.daos.chart import ChartDAO from superset.daos.chart import ChartDAO
from superset.daos.dashboard import DashboardDAO from superset.daos.dashboard import DashboardDAO
from superset.daos.exceptions import DAODeleteFailedError, DAOUpdateFailedError
from superset.exceptions import SupersetSecurityException from superset.exceptions import SupersetSecurityException
from superset.models.slice import Slice from superset.models.slice import Slice
from superset.tags.models import ObjectType from superset.tags.models import ObjectType
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -55,24 +56,20 @@ class UpdateChartCommand(UpdateMixin, BaseCommand):
self._properties = data.copy() self._properties = data.copy()
self._model: Optional[Slice] = None self._model: Optional[Slice] = None
@transaction(on_error=partial(on_error, reraise=ChartUpdateFailedError))
def run(self) -> Model: def run(self) -> Model:
self.validate() self.validate()
assert self._model assert self._model
try: # Update tags
# Update tags if (tags := self._properties.pop("tags", None)) is not None:
tags = self._properties.pop("tags", None) update_tags(ObjectType.chart, self._model.id, self._model.tags, tags)
if tags is not None:
update_tags(ObjectType.chart, self._model.id, self._model.tags, tags)
if self._properties.get("query_context_generation") is None: if self._properties.get("query_context_generation") is None:
self._properties["last_saved_at"] = datetime.now() self._properties["last_saved_at"] = datetime.now()
self._properties["last_saved_by"] = g.user self._properties["last_saved_by"] = g.user
chart = ChartDAO.update(self._model, self._properties)
except (DAOUpdateFailedError, DAODeleteFailedError) as ex: return ChartDAO.update(self._model, self._properties)
logger.exception(ex.exception)
raise ChartUpdateFailedError() from ex
return chart
def validate(self) -> None: def validate(self) -> None:
exceptions: list[ValidationError] = [] exceptions: list[ValidationError] = []

View File

@ -15,6 +15,7 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import logging import logging
from functools import partial
from typing import Optional from typing import Optional
from superset.commands.base import BaseCommand from superset.commands.base import BaseCommand
@ -23,8 +24,8 @@ from superset.commands.css.exceptions import (
CssTemplateNotFoundError, CssTemplateNotFoundError,
) )
from superset.daos.css import CssTemplateDAO from superset.daos.css import CssTemplateDAO
from superset.daos.exceptions import DAODeleteFailedError
from superset.models.core import CssTemplate from superset.models.core import CssTemplate
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -34,15 +35,11 @@ class DeleteCssTemplateCommand(BaseCommand):
self._model_ids = model_ids self._model_ids = model_ids
self._models: Optional[list[CssTemplate]] = None self._models: Optional[list[CssTemplate]] = None
@transaction(on_error=partial(on_error, reraise=CssTemplateDeleteFailedError))
def run(self) -> None: def run(self) -> None:
self.validate() self.validate()
assert self._models assert self._models
CssTemplateDAO.delete(self._models)
try:
CssTemplateDAO.delete(self._models)
except DAODeleteFailedError as ex:
logger.exception(ex.exception)
raise CssTemplateDeleteFailedError() from ex
def validate(self) -> None: def validate(self) -> None:
# Validate/populate model exists # Validate/populate model exists

View File

@ -15,6 +15,7 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import logging import logging
from functools import partial
from typing import Any, Optional from typing import Any, Optional
from flask_appbuilder.models.sqla import Model from flask_appbuilder.models.sqla import Model
@ -28,23 +29,19 @@ from superset.commands.dashboard.exceptions import (
) )
from superset.commands.utils import populate_roles from superset.commands.utils import populate_roles
from superset.daos.dashboard import DashboardDAO from superset.daos.dashboard import DashboardDAO
from superset.daos.exceptions import DAOCreateFailedError from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class CreateDashboardCommand(CreateMixin, BaseCommand): class CreateDashboardCommand(CreateMixin, BaseCommand):
def __init__(self, data: dict[str, Any]): def __init__(self, data: dict[str, Any]) -> None:
self._properties = data.copy() self._properties = data.copy()
@transaction(on_error=partial(on_error, reraise=DashboardCreateFailedError))
def run(self) -> Model: def run(self) -> Model:
self.validate() self.validate()
try: return DashboardDAO.create(attributes=self._properties)
dashboard = DashboardDAO.create(attributes=self._properties, commit=True)
except DAOCreateFailedError as ex:
logger.exception(ex.exception)
raise DashboardCreateFailedError() from ex
return dashboard
def validate(self) -> None: def validate(self) -> None:
exceptions: list[ValidationError] = [] exceptions: list[ValidationError] = []

View File

@ -15,6 +15,7 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import logging import logging
from functools import partial
from typing import Optional from typing import Optional
from flask_babel import lazy_gettext as _ from flask_babel import lazy_gettext as _
@ -28,10 +29,10 @@ from superset.commands.dashboard.exceptions import (
DashboardNotFoundError, DashboardNotFoundError,
) )
from superset.daos.dashboard import DashboardDAO from superset.daos.dashboard import DashboardDAO
from superset.daos.exceptions import DAODeleteFailedError
from superset.daos.report import ReportScheduleDAO from superset.daos.report import ReportScheduleDAO
from superset.exceptions import SupersetSecurityException from superset.exceptions import SupersetSecurityException
from superset.models.dashboard import Dashboard from superset.models.dashboard import Dashboard
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -41,15 +42,11 @@ class DeleteDashboardCommand(BaseCommand):
self._model_ids = model_ids self._model_ids = model_ids
self._models: Optional[list[Dashboard]] = None self._models: Optional[list[Dashboard]] = None
@transaction(on_error=partial(on_error, reraise=DashboardDeleteFailedError))
def run(self) -> None: def run(self) -> None:
self.validate() self.validate()
assert self._models assert self._models
DashboardDAO.delete(self._models)
try:
DashboardDAO.delete(self._models)
except DAODeleteFailedError as ex:
logger.exception(ex.exception)
raise DashboardDeleteFailedError() from ex
def validate(self) -> None: def validate(self) -> None:
# Validate/populate model exists # Validate/populate model exists

View File

@ -36,6 +36,7 @@ from superset.utils.dashboard_filter_scopes_converter import (
convert_filter_scopes, convert_filter_scopes,
copy_filter_scopes, copy_filter_scopes,
) )
from superset.utils.decorators import transaction
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -311,7 +312,6 @@ def import_dashboards(
for dashboard in data["dashboards"]: for dashboard in data["dashboards"]:
import_dashboard(dashboard, dataset_id_mapping, import_time=import_time) import_dashboard(dashboard, dataset_id_mapping, import_time=import_time)
db.session.commit()
class ImportDashboardsCommand(BaseCommand): class ImportDashboardsCommand(BaseCommand):
@ -329,6 +329,7 @@ class ImportDashboardsCommand(BaseCommand):
self.contents = contents self.contents = contents
self.database_id = database_id self.database_id = database_id
@transaction()
def run(self) -> None: def run(self) -> None:
self.validate() self.validate()

View File

@ -188,7 +188,7 @@ def import_dashboard(
if dashboard.id is None: if dashboard.id is None:
db.session.flush() db.session.flush()
if user := get_user(): if (user := get_user()) and user not in dashboard.owners:
dashboard.owners.append(user) dashboard.owners.append(user)
return dashboard return dashboard

View File

@ -15,18 +15,22 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import logging import logging
from functools import partial
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from superset import db
from superset.commands.dashboard.permalink.base import BaseDashboardPermalinkCommand from superset.commands.dashboard.permalink.base import BaseDashboardPermalinkCommand
from superset.commands.key_value.upsert import UpsertKeyValueCommand from superset.commands.key_value.upsert import UpsertKeyValueCommand
from superset.daos.dashboard import DashboardDAO from superset.daos.dashboard import DashboardDAO
from superset.dashboards.permalink.exceptions import DashboardPermalinkCreateFailedError from superset.dashboards.permalink.exceptions import DashboardPermalinkCreateFailedError
from superset.dashboards.permalink.types import DashboardPermalinkState from superset.dashboards.permalink.types import DashboardPermalinkState
from superset.key_value.exceptions import KeyValueCodecEncodeException from superset.key_value.exceptions import (
KeyValueCodecEncodeException,
KeyValueUpsertFailedError,
)
from superset.key_value.utils import encode_permalink_key, get_deterministic_uuid from superset.key_value.utils import encode_permalink_key, get_deterministic_uuid
from superset.utils.core import get_user_id from superset.utils.core import get_user_id
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -47,29 +51,33 @@ class CreateDashboardPermalinkCommand(BaseDashboardPermalinkCommand):
self.dashboard_id = dashboard_id self.dashboard_id = dashboard_id
self.state = state self.state = state
@transaction(
on_error=partial(
on_error,
catches=(
KeyValueCodecEncodeException,
KeyValueUpsertFailedError,
SQLAlchemyError,
),
reraise=DashboardPermalinkCreateFailedError,
),
)
def run(self) -> str: def run(self) -> str:
self.validate() self.validate()
try: dashboard = DashboardDAO.get_by_id_or_slug(self.dashboard_id)
dashboard = DashboardDAO.get_by_id_or_slug(self.dashboard_id) value = {
value = { "dashboardId": str(dashboard.uuid),
"dashboardId": str(dashboard.uuid), "state": self.state,
"state": self.state, }
} user_id = get_user_id()
user_id = get_user_id() key = UpsertKeyValueCommand(
key = UpsertKeyValueCommand( resource=self.resource,
resource=self.resource, key=get_deterministic_uuid(self.salt, (user_id, value)),
key=get_deterministic_uuid(self.salt, (user_id, value)), value=value,
value=value, codec=self.codec,
codec=self.codec, ).run()
).run() assert key.id # for type checks
assert key.id # for type checks return encode_permalink_key(key=key.id, salt=self.salt)
db.session.commit()
return encode_permalink_key(key=key.id, salt=self.salt)
except KeyValueCodecEncodeException as ex:
raise DashboardPermalinkCreateFailedError(str(ex)) from ex
except SQLAlchemyError as ex:
logger.exception("Error running create command")
raise DashboardPermalinkCreateFailedError() from ex
def validate(self) -> None: def validate(self) -> None:
pass pass

View File

@ -15,6 +15,7 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import logging import logging
from functools import partial
from typing import Any, Optional from typing import Any, Optional
from flask_appbuilder.models.sqla import Model from flask_appbuilder.models.sqla import Model
@ -31,12 +32,11 @@ from superset.commands.dashboard.exceptions import (
) )
from superset.commands.utils import populate_roles, update_tags, validate_tags from superset.commands.utils import populate_roles, update_tags, validate_tags
from superset.daos.dashboard import DashboardDAO from superset.daos.dashboard import DashboardDAO
from superset.daos.exceptions import DAODeleteFailedError, DAOUpdateFailedError
from superset.exceptions import SupersetSecurityException from superset.exceptions import SupersetSecurityException
from superset.extensions import db
from superset.models.dashboard import Dashboard from superset.models.dashboard import Dashboard
from superset.tags.models import ObjectType from superset.tags.models import ObjectType
from superset.utils import json from superset.utils import json
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -47,29 +47,22 @@ class UpdateDashboardCommand(UpdateMixin, BaseCommand):
self._properties = data.copy() self._properties = data.copy()
self._model: Optional[Dashboard] = None self._model: Optional[Dashboard] = None
@transaction(on_error=partial(on_error, reraise=DashboardUpdateFailedError))
def run(self) -> Model: def run(self) -> Model:
self.validate() self.validate()
assert self._model assert self._model
try: # Update tags
# Update tags if (tags := self._properties.pop("tags", None)) is not None:
tags = self._properties.pop("tags", None) update_tags(ObjectType.dashboard, self._model.id, self._model.tags, tags)
if tags is not None:
update_tags( dashboard = DashboardDAO.update(self._model, self._properties)
ObjectType.dashboard, self._model.id, self._model.tags, tags if self._properties.get("json_metadata"):
) DashboardDAO.set_dash_metadata(
dashboard,
data=json.loads(self._properties.get("json_metadata", "{}")),
)
dashboard = DashboardDAO.update(self._model, self._properties, commit=False)
if self._properties.get("json_metadata"):
dashboard = DashboardDAO.set_dash_metadata(
dashboard,
data=json.loads(self._properties.get("json_metadata", "{}")),
commit=False,
)
db.session.commit()
except (DAOUpdateFailedError, DAODeleteFailedError) as ex:
logger.exception(ex.exception)
raise DashboardUpdateFailedError() from ex
return dashboard return dashboard
def validate(self) -> None: def validate(self) -> None:

View File

@ -15,6 +15,7 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import logging import logging
from functools import partial
from typing import Any, Optional from typing import Any, Optional
from flask import current_app from flask import current_app
@ -39,11 +40,11 @@ from superset.commands.database.ssh_tunnel.exceptions import (
) )
from superset.commands.database.test_connection import TestConnectionDatabaseCommand from superset.commands.database.test_connection import TestConnectionDatabaseCommand
from superset.daos.database import DatabaseDAO from superset.daos.database import DatabaseDAO
from superset.daos.exceptions import DAOCreateFailedError
from superset.databases.ssh_tunnel.models import SSHTunnel from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.exceptions import SupersetErrorsException from superset.exceptions import SupersetErrorsException
from superset.extensions import db, event_logger, security_manager from superset.extensions import event_logger, security_manager
from superset.models.core import Database from superset.models.core import Database
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
stats_logger = current_app.config["STATS_LOGGER"] stats_logger = current_app.config["STATS_LOGGER"]
@ -53,6 +54,7 @@ class CreateDatabaseCommand(BaseCommand):
def __init__(self, data: dict[str, Any]): def __init__(self, data: dict[str, Any]):
self._properties = data.copy() self._properties = data.copy()
@transaction(on_error=partial(on_error, reraise=DatabaseCreateFailedError))
def run(self) -> Model: def run(self) -> Model:
self.validate() self.validate()
@ -96,8 +98,6 @@ class CreateDatabaseCommand(BaseCommand):
database, ssh_tunnel_properties database, ssh_tunnel_properties
).run() ).run()
db.session.commit()
# add catalog/schema permissions # add catalog/schema permissions
if database.db_engine_spec.supports_catalog: if database.db_engine_spec.supports_catalog:
catalogs = database.get_all_catalog_names( catalogs = database.get_all_catalog_names(
@ -121,14 +121,12 @@ class CreateDatabaseCommand(BaseCommand):
except Exception: # pylint: disable=broad-except except Exception: # pylint: disable=broad-except
logger.warning("Error processing catalog '%s'", catalog) logger.warning("Error processing catalog '%s'", catalog)
continue continue
except ( except (
SSHTunnelInvalidError, SSHTunnelInvalidError,
SSHTunnelCreateFailedError, SSHTunnelCreateFailedError,
SSHTunnelingNotEnabledError, SSHTunnelingNotEnabledError,
SSHTunnelDatabasePortError, SSHTunnelDatabasePortError,
) as ex: ) as ex:
db.session.rollback()
event_logger.log_with_context( event_logger.log_with_context(
action=f"db_creation_failed.{ex.__class__.__name__}.ssh_tunnel", action=f"db_creation_failed.{ex.__class__.__name__}.ssh_tunnel",
engine=self._properties.get("sqlalchemy_uri", "").split(":")[0], engine=self._properties.get("sqlalchemy_uri", "").split(":")[0],
@ -136,11 +134,9 @@ class CreateDatabaseCommand(BaseCommand):
# So we can show the original message # So we can show the original message
raise raise
except ( except (
DAOCreateFailedError,
DatabaseInvalidError, DatabaseInvalidError,
Exception, Exception,
) as ex: ) as ex:
db.session.rollback()
event_logger.log_with_context( event_logger.log_with_context(
action=f"db_creation_failed.{ex.__class__.__name__}", action=f"db_creation_failed.{ex.__class__.__name__}",
engine=database.db_engine_spec.__name__, engine=database.db_engine_spec.__name__,
@ -198,6 +194,6 @@ class CreateDatabaseCommand(BaseCommand):
raise exception raise exception
def _create_database(self) -> Database: def _create_database(self) -> Database:
database = DatabaseDAO.create(attributes=self._properties, commit=False) database = DatabaseDAO.create(attributes=self._properties)
database.set_sqlalchemy_uri(database.sqlalchemy_uri) database.set_sqlalchemy_uri(database.sqlalchemy_uri)
return database return database

View File

@ -15,6 +15,7 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import logging import logging
from functools import partial
from typing import Optional from typing import Optional
from flask_babel import lazy_gettext as _ from flask_babel import lazy_gettext as _
@ -27,9 +28,9 @@ from superset.commands.database.exceptions import (
DatabaseNotFoundError, DatabaseNotFoundError,
) )
from superset.daos.database import DatabaseDAO from superset.daos.database import DatabaseDAO
from superset.daos.exceptions import DAODeleteFailedError
from superset.daos.report import ReportScheduleDAO from superset.daos.report import ReportScheduleDAO
from superset.models.core import Database from superset.models.core import Database
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -39,15 +40,11 @@ class DeleteDatabaseCommand(BaseCommand):
self._model_id = model_id self._model_id = model_id
self._model: Optional[Database] = None self._model: Optional[Database] = None
@transaction(on_error=partial(on_error, reraise=DatabaseDeleteFailedError))
def run(self) -> None: def run(self) -> None:
self.validate() self.validate()
assert self._model assert self._model
DatabaseDAO.delete([self._model])
try:
DatabaseDAO.delete([self._model])
except DAODeleteFailedError as ex:
logger.exception(ex.exception)
raise DatabaseDeleteFailedError() from ex
def validate(self) -> None: def validate(self) -> None:
# Validate/populate model exists # Validate/populate model exists

View File

@ -15,6 +15,7 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import logging import logging
from functools import partial
from typing import Any, Optional from typing import Any, Optional
from flask_appbuilder.models.sqla import Model from flask_appbuilder.models.sqla import Model
@ -28,10 +29,10 @@ from superset.commands.database.ssh_tunnel.exceptions import (
SSHTunnelRequiredFieldValidationError, SSHTunnelRequiredFieldValidationError,
) )
from superset.daos.database import SSHTunnelDAO from superset.daos.database import SSHTunnelDAO
from superset.daos.exceptions import DAOCreateFailedError
from superset.databases.utils import make_url_safe from superset.databases.utils import make_url_safe
from superset.extensions import event_logger from superset.extensions import event_logger
from superset.models.core import Database from superset.models.core import Database
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -44,6 +45,7 @@ class CreateSSHTunnelCommand(BaseCommand):
self._properties["database"] = database self._properties["database"] = database
self._database = database self._database = database
@transaction(on_error=partial(on_error, reraise=SSHTunnelCreateFailedError))
def run(self) -> Model: def run(self) -> Model:
""" """
Create an SSH tunnel. Create an SSH tunnel.
@ -53,11 +55,8 @@ class CreateSSHTunnelCommand(BaseCommand):
:raises SSHTunnelInvalidError: If the configuration are invalid :raises SSHTunnelInvalidError: If the configuration are invalid
""" """
try: self.validate()
self.validate() return SSHTunnelDAO.create(attributes=self._properties)
return SSHTunnelDAO.create(attributes=self._properties, commit=False)
except DAOCreateFailedError as ex:
raise SSHTunnelCreateFailedError() from ex
def validate(self) -> None: def validate(self) -> None:
# TODO(hughhh): check to make sure the server port is not localhost # TODO(hughhh): check to make sure the server port is not localhost

View File

@ -15,6 +15,7 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import logging import logging
from functools import partial
from typing import Optional from typing import Optional
from superset import is_feature_enabled from superset import is_feature_enabled
@ -25,8 +26,8 @@ from superset.commands.database.ssh_tunnel.exceptions import (
SSHTunnelNotFoundError, SSHTunnelNotFoundError,
) )
from superset.daos.database import SSHTunnelDAO from superset.daos.database import SSHTunnelDAO
from superset.daos.exceptions import DAODeleteFailedError
from superset.databases.ssh_tunnel.models import SSHTunnel from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -36,16 +37,13 @@ class DeleteSSHTunnelCommand(BaseCommand):
self._model_id = model_id self._model_id = model_id
self._model: Optional[SSHTunnel] = None self._model: Optional[SSHTunnel] = None
@transaction(on_error=partial(on_error, reraise=SSHTunnelDeleteFailedError))
def run(self) -> None: def run(self) -> None:
if not is_feature_enabled("SSH_TUNNELING"): if not is_feature_enabled("SSH_TUNNELING"):
raise SSHTunnelingNotEnabledError() raise SSHTunnelingNotEnabledError()
self.validate() self.validate()
assert self._model assert self._model
SSHTunnelDAO.delete([self._model])
try:
SSHTunnelDAO.delete([self._model])
except DAODeleteFailedError as ex:
raise SSHTunnelDeleteFailedError() from ex
def validate(self) -> None: def validate(self) -> None:
# Validate/populate model exists # Validate/populate model exists

View File

@ -15,6 +15,7 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import logging import logging
from functools import partial
from typing import Any, Optional from typing import Any, Optional
from flask_appbuilder.models.sqla import Model from flask_appbuilder.models.sqla import Model
@ -28,9 +29,9 @@ from superset.commands.database.ssh_tunnel.exceptions import (
SSHTunnelUpdateFailedError, SSHTunnelUpdateFailedError,
) )
from superset.daos.database import SSHTunnelDAO from superset.daos.database import SSHTunnelDAO
from superset.daos.exceptions import DAOUpdateFailedError
from superset.databases.ssh_tunnel.models import SSHTunnel from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.databases.utils import make_url_safe from superset.databases.utils import make_url_safe
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -41,25 +42,23 @@ class UpdateSSHTunnelCommand(BaseCommand):
self._model_id = model_id self._model_id = model_id
self._model: Optional[SSHTunnel] = None self._model: Optional[SSHTunnel] = None
@transaction(on_error=partial(on_error, reraise=SSHTunnelUpdateFailedError))
def run(self) -> Optional[Model]: def run(self) -> Optional[Model]:
self.validate() self.validate()
try:
if self._model is None:
return None
# unset password if private key is provided if self._model is None:
if self._properties.get("private_key"): return None
self._properties["password"] = None
# unset private key and password if password is provided # unset password if private key is provided
if self._properties.get("password"): if self._properties.get("private_key"):
self._properties["private_key"] = None self._properties["password"] = None
self._properties["private_key_password"] = None
tunnel = SSHTunnelDAO.update(self._model, self._properties) # unset private key and password if password is provided
return tunnel if self._properties.get("password"):
except DAOUpdateFailedError as ex: self._properties["private_key"] = None
raise SSHTunnelUpdateFailedError() from ex self._properties["private_key_password"] = None
return SSHTunnelDAO.update(self._model, self._properties)
def validate(self) -> None: def validate(self) -> None:
# Validate/populate model exists # Validate/populate model exists

View File

@ -18,6 +18,7 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from functools import partial
from typing import Any from typing import Any
from flask_appbuilder.models.sqla import Model from flask_appbuilder.models.sqla import Model
@ -34,16 +35,14 @@ from superset.commands.database.exceptions import (
from superset.commands.database.ssh_tunnel.create import CreateSSHTunnelCommand from superset.commands.database.ssh_tunnel.create import CreateSSHTunnelCommand
from superset.commands.database.ssh_tunnel.delete import DeleteSSHTunnelCommand from superset.commands.database.ssh_tunnel.delete import DeleteSSHTunnelCommand
from superset.commands.database.ssh_tunnel.exceptions import ( from superset.commands.database.ssh_tunnel.exceptions import (
SSHTunnelError,
SSHTunnelingNotEnabledError, SSHTunnelingNotEnabledError,
) )
from superset.commands.database.ssh_tunnel.update import UpdateSSHTunnelCommand from superset.commands.database.ssh_tunnel.update import UpdateSSHTunnelCommand
from superset.daos.database import DatabaseDAO from superset.daos.database import DatabaseDAO
from superset.daos.dataset import DatasetDAO from superset.daos.dataset import DatasetDAO
from superset.daos.exceptions import DAOCreateFailedError, DAOUpdateFailedError
from superset.databases.ssh_tunnel.models import SSHTunnel from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.extensions import db
from superset.models.core import Database from superset.models.core import Database
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -56,6 +55,7 @@ class UpdateDatabaseCommand(BaseCommand):
self._model_id = model_id self._model_id = model_id
self._model: Database | None = None self._model: Database | None = None
@transaction(on_error=partial(on_error, reraise=DatabaseUpdateFailedError))
def run(self) -> Model: def run(self) -> Model:
self._model = DatabaseDAO.find_by_id(self._model_id) self._model = DatabaseDAO.find_by_id(self._model_id)
@ -76,21 +76,10 @@ class UpdateDatabaseCommand(BaseCommand):
# since they're name based # since they're name based
original_database_name = self._model.database_name original_database_name = self._model.database_name
try: database = DatabaseDAO.update(self._model, self._properties)
database = DatabaseDAO.update( database.set_sqlalchemy_uri(database.sqlalchemy_uri)
self._model, ssh_tunnel = self._handle_ssh_tunnel(database)
self._properties, self._refresh_catalogs(database, original_database_name, ssh_tunnel)
commit=False,
)
database.set_sqlalchemy_uri(database.sqlalchemy_uri)
ssh_tunnel = self._handle_ssh_tunnel(database)
self._refresh_catalogs(database, original_database_name, ssh_tunnel)
except SSHTunnelError: # pylint: disable=try-except-raise
# allow exception to bubble for debugbing information
raise
except (DAOUpdateFailedError, DAOCreateFailedError) as ex:
raise DatabaseUpdateFailedError() from ex
return database return database
def _handle_ssh_tunnel(self, database: Database) -> SSHTunnel | None: def _handle_ssh_tunnel(self, database: Database) -> SSHTunnel | None:
@ -101,7 +90,6 @@ class UpdateDatabaseCommand(BaseCommand):
return None return None
if not is_feature_enabled("SSH_TUNNELING"): if not is_feature_enabled("SSH_TUNNELING"):
db.session.rollback()
raise SSHTunnelingNotEnabledError() raise SSHTunnelingNotEnabledError()
current_ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id) current_ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
@ -131,13 +119,13 @@ class UpdateDatabaseCommand(BaseCommand):
This method captures a generic exception, since errors could potentially come This method captures a generic exception, since errors could potentially come
from any of the 50+ database drivers we support. from any of the 50+ database drivers we support.
""" """
try: try:
return database.get_all_catalog_names( return database.get_all_catalog_names(
force=True, force=True,
ssh_tunnel=ssh_tunnel, ssh_tunnel=ssh_tunnel,
) )
except Exception as ex: except Exception as ex:
db.session.rollback()
raise DatabaseConnectionFailedError() from ex raise DatabaseConnectionFailedError() from ex
def _get_schema_names( def _get_schema_names(
@ -152,6 +140,7 @@ class UpdateDatabaseCommand(BaseCommand):
This method captures a generic exception, since errors could potentially come This method captures a generic exception, since errors could potentially come
from any of the 50+ database drivers we support. from any of the 50+ database drivers we support.
""" """
try: try:
return database.get_all_schema_names( return database.get_all_schema_names(
force=True, force=True,
@ -159,7 +148,6 @@ class UpdateDatabaseCommand(BaseCommand):
ssh_tunnel=ssh_tunnel, ssh_tunnel=ssh_tunnel,
) )
except Exception as ex: except Exception as ex:
db.session.rollback()
raise DatabaseConnectionFailedError() from ex raise DatabaseConnectionFailedError() from ex
def _refresh_catalogs( def _refresh_catalogs(
@ -225,8 +213,6 @@ class UpdateDatabaseCommand(BaseCommand):
schemas, schemas,
) )
db.session.commit()
def _refresh_schemas( def _refresh_schemas(
self, self,
database: Database, database: Database,

View File

@ -16,11 +16,11 @@
# under the License. # under the License.
import logging import logging
from abc import abstractmethod from abc import abstractmethod
from functools import partial
from typing import Any, Optional, TypedDict from typing import Any, Optional, TypedDict
import pandas as pd import pandas as pd
from flask_babel import lazy_gettext as _ from flask_babel import lazy_gettext as _
from sqlalchemy.exc import SQLAlchemyError
from werkzeug.datastructures import FileStorage from werkzeug.datastructures import FileStorage
from superset import db from superset import db
@ -37,6 +37,7 @@ from superset.daos.database import DatabaseDAO
from superset.models.core import Database from superset.models.core import Database
from superset.sql_parse import Table from superset.sql_parse import Table
from superset.utils.core import get_user from superset.utils.core import get_user
from superset.utils.decorators import on_error, transaction
from superset.views.database.validators import schema_allows_file_upload from superset.views.database.validators import schema_allows_file_upload
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -144,6 +145,7 @@ class UploadCommand(BaseCommand):
self._file = file self._file = file
self._reader = reader self._reader = reader
@transaction(on_error=partial(on_error, reraise=DatabaseUploadSaveMetadataFailed))
def run(self) -> None: def run(self) -> None:
self.validate() self.validate()
if not self._model: if not self._model:
@ -172,12 +174,6 @@ class UploadCommand(BaseCommand):
sqla_table.fetch_metadata() sqla_table.fetch_metadata()
try:
db.session.commit()
except SQLAlchemyError as ex:
db.session.rollback()
raise DatabaseUploadSaveMetadataFailed() from ex
def validate(self) -> None: def validate(self) -> None:
self._model = DatabaseDAO.find_by_id(self._model_id) self._model = DatabaseDAO.find_by_id(self._model_id)
if not self._model: if not self._model:

View File

@ -15,6 +15,7 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import logging import logging
from functools import partial
from typing import Optional from typing import Optional
from superset import security_manager from superset import security_manager
@ -26,8 +27,8 @@ from superset.commands.dataset.columns.exceptions import (
) )
from superset.connectors.sqla.models import TableColumn from superset.connectors.sqla.models import TableColumn
from superset.daos.dataset import DatasetColumnDAO, DatasetDAO from superset.daos.dataset import DatasetColumnDAO, DatasetDAO
from superset.daos.exceptions import DAODeleteFailedError
from superset.exceptions import SupersetSecurityException from superset.exceptions import SupersetSecurityException
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -38,15 +39,11 @@ class DeleteDatasetColumnCommand(BaseCommand):
self._model_id = model_id self._model_id = model_id
self._model: Optional[TableColumn] = None self._model: Optional[TableColumn] = None
@transaction(on_error=partial(on_error, reraise=DatasetColumnDeleteFailedError))
def run(self) -> None: def run(self) -> None:
self.validate() self.validate()
assert self._model assert self._model
DatasetColumnDAO.delete([self._model])
try:
DatasetColumnDAO.delete([self._model])
except DAODeleteFailedError as ex:
logger.exception(ex.exception)
raise DatasetColumnDeleteFailedError() from ex
def validate(self) -> None: def validate(self) -> None:
# Validate/populate model exists # Validate/populate model exists

View File

@ -15,11 +15,11 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import logging import logging
from functools import partial
from typing import Any, Optional from typing import Any, Optional
from flask_appbuilder.models.sqla import Model from flask_appbuilder.models.sqla import Model
from marshmallow import ValidationError from marshmallow import ValidationError
from sqlalchemy.exc import SQLAlchemyError
from superset.commands.base import BaseCommand, CreateMixin from superset.commands.base import BaseCommand, CreateMixin
from superset.commands.dataset.exceptions import ( from superset.commands.dataset.exceptions import (
@ -31,10 +31,10 @@ from superset.commands.dataset.exceptions import (
TableNotFoundValidationError, TableNotFoundValidationError,
) )
from superset.daos.dataset import DatasetDAO from superset.daos.dataset import DatasetDAO
from superset.daos.exceptions import DAOCreateFailedError
from superset.exceptions import SupersetSecurityException from superset.exceptions import SupersetSecurityException
from superset.extensions import db, security_manager from superset.extensions import security_manager
from superset.sql_parse import Table from superset.sql_parse import Table
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -43,19 +43,12 @@ class CreateDatasetCommand(CreateMixin, BaseCommand):
def __init__(self, data: dict[str, Any]): def __init__(self, data: dict[str, Any]):
self._properties = data.copy() self._properties = data.copy()
@transaction(on_error=partial(on_error, reraise=DatasetCreateFailedError))
def run(self) -> Model: def run(self) -> Model:
self.validate() self.validate()
try:
# Creates SqlaTable (Dataset)
dataset = DatasetDAO.create(attributes=self._properties, commit=False)
# Updates columns and metrics from the dataset dataset = DatasetDAO.create(attributes=self._properties)
dataset.fetch_metadata(commit=False) dataset.fetch_metadata()
db.session.commit()
except (SQLAlchemyError, DAOCreateFailedError) as ex:
logger.warning(ex, exc_info=True)
db.session.rollback()
raise DatasetCreateFailedError() from ex
return dataset return dataset
def validate(self) -> None: def validate(self) -> None:

View File

@ -15,6 +15,7 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import logging import logging
from functools import partial
from typing import Optional from typing import Optional
from superset import security_manager from superset import security_manager
@ -26,8 +27,8 @@ from superset.commands.dataset.exceptions import (
) )
from superset.connectors.sqla.models import SqlaTable from superset.connectors.sqla.models import SqlaTable
from superset.daos.dataset import DatasetDAO from superset.daos.dataset import DatasetDAO
from superset.daos.exceptions import DAODeleteFailedError
from superset.exceptions import SupersetSecurityException from superset.exceptions import SupersetSecurityException
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -37,15 +38,11 @@ class DeleteDatasetCommand(BaseCommand):
self._model_ids = model_ids self._model_ids = model_ids
self._models: Optional[list[SqlaTable]] = None self._models: Optional[list[SqlaTable]] = None
@transaction(on_error=partial(on_error, reraise=DatasetDeleteFailedError))
def run(self) -> None: def run(self) -> None:
self.validate() self.validate()
assert self._models assert self._models
DatasetDAO.delete(self._models)
try:
DatasetDAO.delete(self._models)
except DAODeleteFailedError as ex:
logger.exception(ex.exception)
raise DatasetDeleteFailedError() from ex
def validate(self) -> None: def validate(self) -> None:
# Validate/populate model exists # Validate/populate model exists

View File

@ -15,12 +15,12 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import logging import logging
from functools import partial
from typing import Any from typing import Any
from flask_appbuilder.models.sqla import Model from flask_appbuilder.models.sqla import Model
from flask_babel import gettext as __ from flask_babel import gettext as __
from marshmallow import ValidationError from marshmallow import ValidationError
from sqlalchemy.exc import SQLAlchemyError
from superset.commands.base import BaseCommand, CreateMixin from superset.commands.base import BaseCommand, CreateMixin
from superset.commands.dataset.exceptions import ( from superset.commands.dataset.exceptions import (
@ -32,12 +32,12 @@ from superset.commands.dataset.exceptions import (
from superset.commands.exceptions import DatasourceTypeInvalidError from superset.commands.exceptions import DatasourceTypeInvalidError
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
from superset.daos.dataset import DatasetDAO from superset.daos.dataset import DatasetDAO
from superset.daos.exceptions import DAOCreateFailedError
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import SupersetErrorException from superset.exceptions import SupersetErrorException
from superset.extensions import db from superset.extensions import db
from superset.models.core import Database from superset.models.core import Database
from superset.sql_parse import ParsedQuery, Table from superset.sql_parse import ParsedQuery, Table
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -47,66 +47,61 @@ class DuplicateDatasetCommand(CreateMixin, BaseCommand):
self._base_model: SqlaTable = SqlaTable() self._base_model: SqlaTable = SqlaTable()
self._properties = data.copy() self._properties = data.copy()
@transaction(on_error=partial(on_error, reraise=DatasetDuplicateFailedError))
def run(self) -> Model: def run(self) -> Model:
self.validate() self.validate()
try: database_id = self._base_model.database_id
database_id = self._base_model.database_id table_name = self._properties["table_name"]
table_name = self._properties["table_name"] owners = self._properties["owners"]
owners = self._properties["owners"] database = db.session.query(Database).get(database_id)
database = db.session.query(Database).get(database_id) if not database:
if not database: raise SupersetErrorException(
raise SupersetErrorException( SupersetError(
SupersetError( message=__("The database was not found."),
message=__("The database was not found."), error_type=SupersetErrorType.DATABASE_NOT_FOUND_ERROR,
error_type=SupersetErrorType.DATABASE_NOT_FOUND_ERROR, level=ErrorLevel.ERROR,
level=ErrorLevel.ERROR, ),
), status=404,
status=404, )
) table = SqlaTable(table_name=table_name, owners=owners)
table = SqlaTable(table_name=table_name, owners=owners) table.database = database
table.database = database table.schema = self._base_model.schema
table.schema = self._base_model.schema table.template_params = self._base_model.template_params
table.template_params = self._base_model.template_params table.normalize_columns = self._base_model.normalize_columns
table.normalize_columns = self._base_model.normalize_columns table.always_filter_main_dttm = self._base_model.always_filter_main_dttm
table.always_filter_main_dttm = self._base_model.always_filter_main_dttm table.is_sqllab_view = True
table.is_sqllab_view = True table.sql = ParsedQuery(
table.sql = ParsedQuery( self._base_model.sql,
self._base_model.sql, engine=database.db_engine_spec.engine,
engine=database.db_engine_spec.engine, ).stripped()
).stripped() db.session.add(table)
db.session.add(table) cols = []
cols = [] for config_ in self._base_model.columns:
for config_ in self._base_model.columns: column_name = config_.column_name
column_name = config_.column_name col = TableColumn(
col = TableColumn( column_name=column_name,
column_name=column_name, verbose_name=config_.verbose_name,
verbose_name=config_.verbose_name, expression=config_.expression,
expression=config_.expression, filterable=True,
filterable=True, groupby=True,
groupby=True, is_dttm=config_.is_dttm,
is_dttm=config_.is_dttm, type=config_.type,
type=config_.type, description=config_.description,
description=config_.description, )
) cols.append(col)
cols.append(col) table.columns = cols
table.columns = cols mets = []
mets = [] for config_ in self._base_model.metrics:
for config_ in self._base_model.metrics: metric_name = config_.metric_name
metric_name = config_.metric_name met = SqlMetric(
met = SqlMetric( metric_name=metric_name,
metric_name=metric_name, verbose_name=config_.verbose_name,
verbose_name=config_.verbose_name, expression=config_.expression,
expression=config_.expression, metric_type=config_.metric_type,
metric_type=config_.metric_type, description=config_.description,
description=config_.description, )
) mets.append(met)
mets.append(met) table.metrics = mets
table.metrics = mets
db.session.commit()
except (SQLAlchemyError, DAOCreateFailedError) as ex:
logger.warning(ex, exc_info=True)
db.session.rollback()
raise DatasetDuplicateFailedError() from ex
return table return table
def validate(self) -> None: def validate(self) -> None:

View File

@ -34,6 +34,7 @@ from superset.connectors.sqla.models import (
) )
from superset.models.core import Database from superset.models.core import Database
from superset.utils import json from superset.utils import json
from superset.utils.decorators import transaction
from superset.utils.dict_import_export import DATABASES_KEY from superset.utils.dict_import_export import DATABASES_KEY
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -211,7 +212,6 @@ def import_from_dict(data: dict[str, Any], sync: Optional[list[str]] = None) ->
logger.info("Importing %d %s", len(data.get(DATABASES_KEY, [])), DATABASES_KEY) logger.info("Importing %d %s", len(data.get(DATABASES_KEY, [])), DATABASES_KEY)
for database in data.get(DATABASES_KEY, []): for database in data.get(DATABASES_KEY, []):
Database.import_from_dict(database, sync=sync) Database.import_from_dict(database, sync=sync)
db.session.commit()
else: else:
logger.info("Supplied object is not a dictionary.") logger.info("Supplied object is not a dictionary.")
@ -240,10 +240,10 @@ class ImportDatasetsCommand(BaseCommand):
if kwargs.get("sync_metrics"): if kwargs.get("sync_metrics"):
self.sync.append("metrics") self.sync.append("metrics")
@transaction()
def run(self) -> None: def run(self) -> None:
self.validate() self.validate()
# TODO (betodealmeida): add rollback in case of error
for file_name, config in self._configs.items(): for file_name, config in self._configs.items():
logger.info("Importing dataset from file %s", file_name) logger.info("Importing dataset from file %s", file_name)
if isinstance(config, dict): if isinstance(config, dict):
@ -260,7 +260,6 @@ class ImportDatasetsCommand(BaseCommand):
) )
dataset["database_id"] = database.id dataset["database_id"] = database.id
SqlaTable.import_from_dict(dataset, sync=self.sync) SqlaTable.import_from_dict(dataset, sync=self.sync)
db.session.commit()
def validate(self) -> None: def validate(self) -> None:
# ensure all files are YAML # ensure all files are YAML

View File

@ -178,7 +178,7 @@ def import_dataset(
if data_uri and (not table_exists or force_data): if data_uri and (not table_exists or force_data):
load_data(data_uri, dataset, dataset.database) load_data(data_uri, dataset, dataset.database)
if user := get_user(): if (user := get_user()) and user not in dataset.owners:
dataset.owners.append(user) dataset.owners.append(user)
return dataset return dataset

View File

@ -15,6 +15,7 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import logging import logging
from functools import partial
from typing import Optional from typing import Optional
from superset import security_manager from superset import security_manager
@ -26,8 +27,8 @@ from superset.commands.dataset.metrics.exceptions import (
) )
from superset.connectors.sqla.models import SqlMetric from superset.connectors.sqla.models import SqlMetric
from superset.daos.dataset import DatasetDAO, DatasetMetricDAO from superset.daos.dataset import DatasetDAO, DatasetMetricDAO
from superset.daos.exceptions import DAODeleteFailedError
from superset.exceptions import SupersetSecurityException from superset.exceptions import SupersetSecurityException
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -38,15 +39,11 @@ class DeleteDatasetMetricCommand(BaseCommand):
self._model_id = model_id self._model_id = model_id
self._model: Optional[SqlMetric] = None self._model: Optional[SqlMetric] = None
@transaction(on_error=partial(on_error, reraise=DatasetMetricDeleteFailedError))
def run(self) -> None: def run(self) -> None:
self.validate() self.validate()
assert self._model assert self._model
DatasetMetricDAO.delete([self._model])
try:
DatasetMetricDAO.delete([self._model])
except DAODeleteFailedError as ex:
logger.exception(ex.exception)
raise DatasetMetricDeleteFailedError() from ex
def validate(self) -> None: def validate(self) -> None:
# Validate/populate model exists # Validate/populate model exists

View File

@ -15,6 +15,7 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import logging import logging
from functools import partial
from typing import Optional from typing import Optional
from flask_appbuilder.models.sqla import Model from flask_appbuilder.models.sqla import Model
@ -29,6 +30,7 @@ from superset.commands.dataset.exceptions import (
from superset.connectors.sqla.models import SqlaTable from superset.connectors.sqla.models import SqlaTable
from superset.daos.dataset import DatasetDAO from superset.daos.dataset import DatasetDAO
from superset.exceptions import SupersetSecurityException from superset.exceptions import SupersetSecurityException
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -38,16 +40,12 @@ class RefreshDatasetCommand(BaseCommand):
self._model_id = model_id self._model_id = model_id
self._model: Optional[SqlaTable] = None self._model: Optional[SqlaTable] = None
@transaction(on_error=partial(on_error, reraise=DatasetRefreshFailedError))
def run(self) -> Model: def run(self) -> Model:
self.validate() self.validate()
if self._model: assert self._model
try: self._model.fetch_metadata()
self._model.fetch_metadata() return self._model
return self._model
except Exception as ex:
logger.exception(ex)
raise DatasetRefreshFailedError() from ex
raise DatasetRefreshFailedError()
def validate(self) -> None: def validate(self) -> None:
# Validate/populate model exists # Validate/populate model exists

View File

@ -16,10 +16,12 @@
# under the License. # under the License.
import logging import logging
from collections import Counter from collections import Counter
from functools import partial
from typing import Any, Optional from typing import Any, Optional
from flask_appbuilder.models.sqla import Model from flask_appbuilder.models.sqla import Model
from marshmallow import ValidationError from marshmallow import ValidationError
from sqlalchemy.exc import SQLAlchemyError
from superset import security_manager from superset import security_manager
from superset.commands.base import BaseCommand, UpdateMixin from superset.commands.base import BaseCommand, UpdateMixin
@ -39,9 +41,9 @@ from superset.commands.dataset.exceptions import (
) )
from superset.connectors.sqla.models import SqlaTable from superset.connectors.sqla.models import SqlaTable
from superset.daos.dataset import DatasetDAO from superset.daos.dataset import DatasetDAO
from superset.daos.exceptions import DAOUpdateFailedError
from superset.exceptions import SupersetSecurityException from superset.exceptions import SupersetSecurityException
from superset.sql_parse import Table from superset.sql_parse import Table
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -59,19 +61,20 @@ class UpdateDatasetCommand(UpdateMixin, BaseCommand):
self.override_columns = override_columns self.override_columns = override_columns
self._properties["override_columns"] = override_columns self._properties["override_columns"] = override_columns
@transaction(
on_error=partial(
on_error,
catches=(
SQLAlchemyError,
ValueError,
),
reraise=DatasetUpdateFailedError,
)
)
def run(self) -> Model: def run(self) -> Model:
self.validate() self.validate()
if self._model: assert self._model
try: return DatasetDAO.update(self._model, attributes=self._properties)
dataset = DatasetDAO.update(
self._model,
attributes=self._properties,
)
return dataset
except DAOUpdateFailedError as ex:
logger.exception(ex.exception)
raise DatasetUpdateFailedError() from ex
raise DatasetUpdateFailedError()
def validate(self) -> None: def validate(self) -> None:
exceptions: list[ValidationError] = [] exceptions: list[ValidationError] = []

View File

@ -15,18 +15,22 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import logging import logging
from functools import partial
from typing import Any, Optional from typing import Any, Optional
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from superset import db
from superset.commands.explore.permalink.base import BaseExplorePermalinkCommand from superset.commands.explore.permalink.base import BaseExplorePermalinkCommand
from superset.commands.key_value.create import CreateKeyValueCommand from superset.commands.key_value.create import CreateKeyValueCommand
from superset.explore.permalink.exceptions import ExplorePermalinkCreateFailedError from superset.explore.permalink.exceptions import ExplorePermalinkCreateFailedError
from superset.explore.utils import check_access as check_chart_access from superset.explore.utils import check_access as check_chart_access
from superset.key_value.exceptions import KeyValueCodecEncodeException from superset.key_value.exceptions import (
KeyValueCodecEncodeException,
KeyValueCreateFailedError,
)
from superset.key_value.utils import encode_permalink_key from superset.key_value.utils import encode_permalink_key
from superset.utils.core import DatasourceType from superset.utils.core import DatasourceType
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -37,35 +41,39 @@ class CreateExplorePermalinkCommand(BaseExplorePermalinkCommand):
self.datasource: str = state["formData"]["datasource"] self.datasource: str = state["formData"]["datasource"]
self.state = state self.state = state
@transaction(
on_error=partial(
on_error,
catches=(
KeyValueCodecEncodeException,
KeyValueCreateFailedError,
SQLAlchemyError,
),
reraise=ExplorePermalinkCreateFailedError,
),
)
def run(self) -> str: def run(self) -> str:
self.validate() self.validate()
try: d_id, d_type = self.datasource.split("__")
d_id, d_type = self.datasource.split("__") datasource_id = int(d_id)
datasource_id = int(d_id) datasource_type = DatasourceType(d_type)
datasource_type = DatasourceType(d_type) check_chart_access(datasource_id, self.chart_id, datasource_type)
check_chart_access(datasource_id, self.chart_id, datasource_type) value = {
value = { "chartId": self.chart_id,
"chartId": self.chart_id, "datasourceId": datasource_id,
"datasourceId": datasource_id, "datasourceType": datasource_type.value,
"datasourceType": datasource_type.value, "datasource": self.datasource,
"datasource": self.datasource, "state": self.state,
"state": self.state, }
} command = CreateKeyValueCommand(
command = CreateKeyValueCommand( resource=self.resource,
resource=self.resource, value=value,
value=value, codec=self.codec,
codec=self.codec, )
) key = command.run()
key = command.run() if key.id is None:
if key.id is None: raise ExplorePermalinkCreateFailedError("Unexpected missing key id")
raise ExplorePermalinkCreateFailedError("Unexpected missing key id") return encode_permalink_key(key=key.id, salt=self.salt)
db.session.commit()
return encode_permalink_key(key=key.id, salt=self.salt)
except KeyValueCodecEncodeException as ex:
raise ExplorePermalinkCreateFailedError(str(ex)) from ex
except SQLAlchemyError as ex:
logger.exception("Error running create command")
raise ExplorePermalinkCreateFailedError() from ex
def validate(self) -> None: def validate(self) -> None:
pass pass

View File

@ -32,6 +32,7 @@ from superset.commands.importers.v1.utils import (
) )
from superset.daos.base import BaseDAO from superset.daos.base import BaseDAO
from superset.models.core import Database # noqa: F401 from superset.models.core import Database # noqa: F401
from superset.utils.decorators import transaction
class ImportModelsCommand(BaseCommand): class ImportModelsCommand(BaseCommand):
@ -67,18 +68,15 @@ class ImportModelsCommand(BaseCommand):
def _get_uuids(cls) -> set[str]: def _get_uuids(cls) -> set[str]:
return {str(model.uuid) for model in db.session.query(cls.dao.model_cls).all()} return {str(model.uuid) for model in db.session.query(cls.dao.model_cls).all()}
@transaction()
def run(self) -> None: def run(self) -> None:
self.validate() self.validate()
# rollback to prevent partial imports
try: try:
self._import(self._configs, self.overwrite) self._import(self._configs, self.overwrite)
db.session.commit()
except CommandException: except CommandException:
db.session.rollback()
raise raise
except Exception as ex: except Exception as ex:
db.session.rollback()
raise self.import_error() from ex raise self.import_error() from ex
def validate(self) -> None: # noqa: F811 def validate(self) -> None: # noqa: F811

View File

@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
from functools import partial
from typing import Any, Optional from typing import Any, Optional
from marshmallow import Schema from marshmallow import Schema
@ -44,6 +45,7 @@ from superset.datasets.schemas import ImportV1DatasetSchema
from superset.migrations.shared.native_filters import migrate_dashboard from superset.migrations.shared.native_filters import migrate_dashboard
from superset.models.dashboard import dashboard_slices from superset.models.dashboard import dashboard_slices
from superset.queries.saved_queries.schemas import ImportV1SavedQuerySchema from superset.queries.saved_queries.schemas import ImportV1SavedQuerySchema
from superset.utils.decorators import on_error, transaction
class ImportAssetsCommand(BaseCommand): class ImportAssetsCommand(BaseCommand):
@ -153,16 +155,16 @@ class ImportAssetsCommand(BaseCommand):
if chart.viz_type == "filter_box": if chart.viz_type == "filter_box":
db.session.delete(chart) db.session.delete(chart)
@transaction(
on_error=partial(
on_error,
catches=(Exception,),
reraise=ImportFailedError,
)
)
def run(self) -> None: def run(self) -> None:
self.validate() self.validate()
self._import(self._configs)
# rollback to prevent partial imports
try:
self._import(self._configs)
db.session.commit()
except Exception as ex:
db.session.rollback()
raise ImportFailedError() from ex
def validate(self) -> None: def validate(self) -> None:
exceptions: list[ValidationError] = [] exceptions: list[ValidationError] = []

View File

@ -43,6 +43,7 @@ from superset.datasets.schemas import ImportV1DatasetSchema
from superset.models.dashboard import dashboard_slices from superset.models.dashboard import dashboard_slices
from superset.utils.core import get_example_default_schema from superset.utils.core import get_example_default_schema
from superset.utils.database import get_example_database from superset.utils.database import get_example_database
from superset.utils.decorators import transaction
class ImportExamplesCommand(ImportModelsCommand): class ImportExamplesCommand(ImportModelsCommand):
@ -62,19 +63,17 @@ class ImportExamplesCommand(ImportModelsCommand):
super().__init__(contents, *args, **kwargs) super().__init__(contents, *args, **kwargs)
self.force_data = kwargs.get("force_data", False) self.force_data = kwargs.get("force_data", False)
@transaction()
def run(self) -> None: def run(self) -> None:
self.validate() self.validate()
# rollback to prevent partial imports
try: try:
self._import( self._import(
self._configs, self._configs,
self.overwrite, self.overwrite,
self.force_data, self.force_data,
) )
db.session.commit()
except Exception as ex: except Exception as ex:
db.session.rollback()
raise self.import_error() from ex raise self.import_error() from ex
@classmethod @classmethod

View File

@ -16,17 +16,17 @@
# under the License. # under the License.
import logging import logging
from datetime import datetime from datetime import datetime
from functools import partial
from typing import Any, Optional, Union from typing import Any, Optional, Union
from uuid import UUID from uuid import UUID
from sqlalchemy.exc import SQLAlchemyError
from superset import db from superset import db
from superset.commands.base import BaseCommand from superset.commands.base import BaseCommand
from superset.key_value.exceptions import KeyValueCreateFailedError from superset.key_value.exceptions import KeyValueCreateFailedError
from superset.key_value.models import KeyValueEntry from superset.key_value.models import KeyValueEntry
from superset.key_value.types import Key, KeyValueCodec, KeyValueResource from superset.key_value.types import Key, KeyValueCodec, KeyValueResource
from superset.utils.core import get_user_id from superset.utils.core import get_user_id
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -62,6 +62,7 @@ class CreateKeyValueCommand(BaseCommand):
self.key = key self.key = key
self.expires_on = expires_on self.expires_on = expires_on
@transaction(on_error=partial(on_error, reraise=KeyValueCreateFailedError))
def run(self) -> Key: def run(self) -> Key:
""" """
Persist the value Persist the value
@ -69,11 +70,8 @@ class CreateKeyValueCommand(BaseCommand):
:return: the key associated with the persisted value :return: the key associated with the persisted value
""" """
try:
return self.create() return self.create()
except SQLAlchemyError as ex:
db.session.rollback()
raise KeyValueCreateFailedError() from ex
def validate(self) -> None: def validate(self) -> None:
pass pass
@ -98,6 +96,7 @@ class CreateKeyValueCommand(BaseCommand):
entry.id = self.key entry.id = self.key
except ValueError as ex: except ValueError as ex:
raise KeyValueCreateFailedError() from ex raise KeyValueCreateFailedError() from ex
db.session.add(entry) db.session.add(entry)
db.session.flush() db.session.flush()
return Key(id=entry.id, uuid=entry.uuid) return Key(id=entry.id, uuid=entry.uuid)

View File

@ -15,17 +15,17 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import logging import logging
from functools import partial
from typing import Union from typing import Union
from uuid import UUID from uuid import UUID
from sqlalchemy.exc import SQLAlchemyError
from superset import db from superset import db
from superset.commands.base import BaseCommand from superset.commands.base import BaseCommand
from superset.key_value.exceptions import KeyValueDeleteFailedError from superset.key_value.exceptions import KeyValueDeleteFailedError
from superset.key_value.models import KeyValueEntry from superset.key_value.models import KeyValueEntry
from superset.key_value.types import KeyValueResource from superset.key_value.types import KeyValueResource
from superset.key_value.utils import get_filter from superset.key_value.utils import get_filter
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -45,20 +45,19 @@ class DeleteKeyValueCommand(BaseCommand):
self.resource = resource self.resource = resource
self.key = key self.key = key
@transaction(on_error=partial(on_error, reraise=KeyValueDeleteFailedError))
def run(self) -> bool: def run(self) -> bool:
try: return self.delete()
return self.delete()
except SQLAlchemyError as ex:
db.session.rollback()
raise KeyValueDeleteFailedError() from ex
def validate(self) -> None: def validate(self) -> None:
pass pass
def delete(self) -> bool: def delete(self) -> bool:
filter_ = get_filter(self.resource, self.key) if (
if entry := db.session.query(KeyValueEntry).filter_by(**filter_).first(): entry := db.session.query(KeyValueEntry)
.filter_by(**get_filter(self.resource, self.key))
.first()
):
db.session.delete(entry) db.session.delete(entry)
db.session.flush()
return True return True
return False return False

View File

@ -16,15 +16,16 @@
# under the License. # under the License.
import logging import logging
from datetime import datetime from datetime import datetime
from functools import partial
from sqlalchemy import and_ from sqlalchemy import and_
from sqlalchemy.exc import SQLAlchemyError
from superset import db from superset import db
from superset.commands.base import BaseCommand from superset.commands.base import BaseCommand
from superset.key_value.exceptions import KeyValueDeleteFailedError from superset.key_value.exceptions import KeyValueDeleteFailedError
from superset.key_value.models import KeyValueEntry from superset.key_value.models import KeyValueEntry
from superset.key_value.types import KeyValueResource from superset.key_value.types import KeyValueResource
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -41,12 +42,9 @@ class DeleteExpiredKeyValueCommand(BaseCommand):
""" """
self.resource = resource self.resource = resource
@transaction(on_error=partial(on_error, reraise=KeyValueDeleteFailedError))
def run(self) -> None: def run(self) -> None:
try: self.delete_expired()
self.delete_expired()
except SQLAlchemyError as ex:
db.session.rollback()
raise KeyValueDeleteFailedError() from ex
def validate(self) -> None: def validate(self) -> None:
pass pass
@ -62,4 +60,3 @@ class DeleteExpiredKeyValueCommand(BaseCommand):
) )
.delete() .delete()
) )
db.session.flush()

View File

@ -17,11 +17,10 @@
import logging import logging
from datetime import datetime from datetime import datetime
from functools import partial
from typing import Any, Optional, Union from typing import Any, Optional, Union
from uuid import UUID from uuid import UUID
from sqlalchemy.exc import SQLAlchemyError
from superset import db from superset import db
from superset.commands.base import BaseCommand from superset.commands.base import BaseCommand
from superset.key_value.exceptions import KeyValueUpdateFailedError from superset.key_value.exceptions import KeyValueUpdateFailedError
@ -29,6 +28,7 @@ from superset.key_value.models import KeyValueEntry
from superset.key_value.types import Key, KeyValueCodec, KeyValueResource from superset.key_value.types import Key, KeyValueCodec, KeyValueResource
from superset.key_value.utils import get_filter from superset.key_value.utils import get_filter
from superset.utils.core import get_user_id from superset.utils.core import get_user_id
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -64,12 +64,9 @@ class UpdateKeyValueCommand(BaseCommand):
self.codec = codec self.codec = codec
self.expires_on = expires_on self.expires_on = expires_on
@transaction(on_error=partial(on_error, reraise=KeyValueUpdateFailedError))
def run(self) -> Optional[Key]: def run(self) -> Optional[Key]:
try: return self.update()
return self.update()
except SQLAlchemyError as ex:
db.session.rollback()
raise KeyValueUpdateFailedError() from ex
def validate(self) -> None: def validate(self) -> None:
pass pass

View File

@ -17,6 +17,7 @@
import logging import logging
from datetime import datetime from datetime import datetime
from functools import partial
from typing import Any, Optional, Union from typing import Any, Optional, Union
from uuid import UUID from uuid import UUID
@ -33,6 +34,7 @@ from superset.key_value.models import KeyValueEntry
from superset.key_value.types import Key, KeyValueCodec, KeyValueResource from superset.key_value.types import Key, KeyValueCodec, KeyValueResource
from superset.key_value.utils import get_filter from superset.key_value.utils import get_filter
from superset.utils.core import get_user_id from superset.utils.core import get_user_id
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -68,27 +70,29 @@ class UpsertKeyValueCommand(BaseCommand):
self.codec = codec self.codec = codec
self.expires_on = expires_on self.expires_on = expires_on
@transaction(
on_error=partial(
on_error,
catches=(KeyValueCreateFailedError, SQLAlchemyError),
reraise=KeyValueUpsertFailedError,
),
)
def run(self) -> Key: def run(self) -> Key:
try: return self.upsert()
return self.upsert()
except (KeyValueCreateFailedError, SQLAlchemyError) as ex:
db.session.rollback()
raise KeyValueUpsertFailedError() from ex
def validate(self) -> None: def validate(self) -> None:
pass pass
def upsert(self) -> Key: def upsert(self) -> Key:
filter_ = get_filter(self.resource, self.key) if (
entry: KeyValueEntry = ( entry := db.session.query(KeyValueEntry)
db.session.query(KeyValueEntry).filter_by(**filter_).first() .filter_by(**get_filter(self.resource, self.key))
) .first()
if entry: ):
entry.value = self.codec.encode(self.value) entry.value = self.codec.encode(self.value)
entry.expires_on = self.expires_on entry.expires_on = self.expires_on
entry.changed_on = datetime.now() entry.changed_on = datetime.now()
entry.changed_by_fk = get_user_id() entry.changed_by_fk = get_user_id()
db.session.flush()
return Key(entry.id, entry.uuid) return Key(entry.id, entry.uuid)
return CreateKeyValueCommand( return CreateKeyValueCommand(

View File

@ -15,6 +15,7 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import logging import logging
from functools import partial
from typing import Optional from typing import Optional
from superset.commands.base import BaseCommand from superset.commands.base import BaseCommand
@ -22,9 +23,9 @@ from superset.commands.query.exceptions import (
SavedQueryDeleteFailedError, SavedQueryDeleteFailedError,
SavedQueryNotFoundError, SavedQueryNotFoundError,
) )
from superset.daos.exceptions import DAODeleteFailedError
from superset.daos.query import SavedQueryDAO from superset.daos.query import SavedQueryDAO
from superset.models.dashboard import Dashboard from superset.models.dashboard import Dashboard
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -34,15 +35,11 @@ class DeleteSavedQueryCommand(BaseCommand):
self._model_ids = model_ids self._model_ids = model_ids
self._models: Optional[list[Dashboard]] = None self._models: Optional[list[Dashboard]] = None
@transaction(on_error=partial(on_error, reraise=SavedQueryDeleteFailedError))
def run(self) -> None: def run(self) -> None:
self.validate() self.validate()
assert self._models assert self._models
SavedQueryDAO.delete(self._models)
try:
SavedQueryDAO.delete(self._models)
except DAODeleteFailedError as ex:
logger.exception(ex.exception)
raise SavedQueryDeleteFailedError() from ex
def validate(self) -> None: def validate(self) -> None:
# Validate/populate model exists # Validate/populate model exists

View File

@ -15,6 +15,7 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import logging import logging
from functools import partial
from typing import Any, Optional from typing import Any, Optional
from flask_babel import gettext as _ from flask_babel import gettext as _
@ -31,7 +32,6 @@ from superset.commands.report.exceptions import (
ReportScheduleNameUniquenessValidationError, ReportScheduleNameUniquenessValidationError,
) )
from superset.daos.database import DatabaseDAO from superset.daos.database import DatabaseDAO
from superset.daos.exceptions import DAOCreateFailedError
from superset.daos.report import ReportScheduleDAO from superset.daos.report import ReportScheduleDAO
from superset.reports.models import ( from superset.reports.models import (
ReportCreationMethod, ReportCreationMethod,
@ -40,6 +40,7 @@ from superset.reports.models import (
) )
from superset.reports.types import ReportScheduleExtra from superset.reports.types import ReportScheduleExtra
from superset.utils import json from superset.utils import json
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -48,13 +49,10 @@ class CreateReportScheduleCommand(CreateMixin, BaseReportScheduleCommand):
def __init__(self, data: dict[str, Any]): def __init__(self, data: dict[str, Any]):
self._properties = data.copy() self._properties = data.copy()
@transaction(on_error=partial(on_error, reraise=ReportScheduleCreateFailedError))
def run(self) -> ReportSchedule: def run(self) -> ReportSchedule:
self.validate() self.validate()
try: return ReportScheduleDAO.create(attributes=self._properties)
return ReportScheduleDAO.create(attributes=self._properties)
except DAOCreateFailedError as ex:
logger.exception(ex.exception)
raise ReportScheduleCreateFailedError() from ex
def validate(self) -> None: def validate(self) -> None:
""" """

View File

@ -15,6 +15,7 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import logging import logging
from functools import partial
from typing import Optional from typing import Optional
from superset import security_manager from superset import security_manager
@ -24,10 +25,10 @@ from superset.commands.report.exceptions import (
ReportScheduleForbiddenError, ReportScheduleForbiddenError,
ReportScheduleNotFoundError, ReportScheduleNotFoundError,
) )
from superset.daos.exceptions import DAODeleteFailedError
from superset.daos.report import ReportScheduleDAO from superset.daos.report import ReportScheduleDAO
from superset.exceptions import SupersetSecurityException from superset.exceptions import SupersetSecurityException
from superset.reports.models import ReportSchedule from superset.reports.models import ReportSchedule
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -37,15 +38,11 @@ class DeleteReportScheduleCommand(BaseCommand):
self._model_ids = model_ids self._model_ids = model_ids
self._models: Optional[list[ReportSchedule]] = None self._models: Optional[list[ReportSchedule]] = None
@transaction(on_error=partial(on_error, reraise=ReportScheduleDeleteFailedError))
def run(self) -> None: def run(self) -> None:
self.validate() self.validate()
assert self._models assert self._models
ReportScheduleDAO.delete(self._models)
try:
ReportScheduleDAO.delete(self._models)
except DAODeleteFailedError as ex:
logger.exception(ex.exception)
raise ReportScheduleDeleteFailedError() from ex
def validate(self) -> None: def validate(self) -> None:
# Validate/populate model exists # Validate/populate model exists

View File

@ -69,7 +69,7 @@ from superset.tasks.utils import get_executor
from superset.utils import json from superset.utils import json
from superset.utils.core import HeaderDataType, override_user from superset.utils.core import HeaderDataType, override_user
from superset.utils.csv import get_chart_csv_data, get_chart_dataframe from superset.utils.csv import get_chart_csv_data, get_chart_dataframe
from superset.utils.decorators import logs_context from superset.utils.decorators import logs_context, transaction
from superset.utils.pdf import build_pdf_from_screenshots from superset.utils.pdf import build_pdf_from_screenshots
from superset.utils.screenshots import ChartScreenshot, DashboardScreenshot from superset.utils.screenshots import ChartScreenshot, DashboardScreenshot
from superset.utils.urls import get_url_path from superset.utils.urls import get_url_path
@ -120,7 +120,6 @@ class BaseReportState:
self._report_schedule.last_state = state self._report_schedule.last_state = state
self._report_schedule.last_eval_dttm = datetime.utcnow() self._report_schedule.last_eval_dttm = datetime.utcnow()
db.session.commit()
def create_log(self, error_message: Optional[str] = None) -> None: def create_log(self, error_message: Optional[str] = None) -> None:
""" """
@ -138,7 +137,7 @@ class BaseReportState:
uuid=self._execution_id, uuid=self._execution_id,
) )
db.session.add(log) db.session.add(log)
db.session.commit() db.session.commit() # pylint: disable=consider-using-transaction
def _get_url( def _get_url(
self, self,
@ -690,6 +689,7 @@ class ReportScheduleStateMachine: # pylint: disable=too-few-public-methods
self._report_schedule = report_schedule self._report_schedule = report_schedule
self._scheduled_dttm = scheduled_dttm self._scheduled_dttm = scheduled_dttm
@transaction()
def run(self) -> None: def run(self) -> None:
for state_cls in self.states_cls: for state_cls in self.states_cls:
if (self._report_schedule.last_state is None and state_cls.initial) or ( if (self._report_schedule.last_state is None and state_cls.initial) or (
@ -718,6 +718,7 @@ class AsyncExecuteReportScheduleCommand(BaseCommand):
self._scheduled_dttm = scheduled_dttm self._scheduled_dttm = scheduled_dttm
self._execution_id = UUID(task_id) self._execution_id = UUID(task_id)
@transaction()
def run(self) -> None: def run(self) -> None:
try: try:
self.validate() self.validate()

View File

@ -17,12 +17,14 @@
import logging import logging
from datetime import datetime, timedelta from datetime import datetime, timedelta
from sqlalchemy.exc import SQLAlchemyError
from superset import db from superset import db
from superset.commands.base import BaseCommand from superset.commands.base import BaseCommand
from superset.commands.report.exceptions import ReportSchedulePruneLogError from superset.commands.report.exceptions import ReportSchedulePruneLogError
from superset.daos.exceptions import DAODeleteFailedError
from superset.daos.report import ReportScheduleDAO from superset.daos.report import ReportScheduleDAO
from superset.reports.models import ReportSchedule from superset.reports.models import ReportSchedule
from superset.utils.decorators import transaction
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -32,9 +34,7 @@ class AsyncPruneReportScheduleLogCommand(BaseCommand):
Prunes logs from all report schedules Prunes logs from all report schedules
""" """
def __init__(self, worker_context: bool = True): @transaction()
self._worker_context = worker_context
def run(self) -> None: def run(self) -> None:
self.validate() self.validate()
prune_errors = [] prune_errors = []
@ -46,15 +46,15 @@ class AsyncPruneReportScheduleLogCommand(BaseCommand):
) )
try: try:
row_count = ReportScheduleDAO.bulk_delete_logs( row_count = ReportScheduleDAO.bulk_delete_logs(
report_schedule, from_date, commit=False report_schedule,
from_date,
) )
db.session.commit()
logger.info( logger.info(
"Deleted %s logs for report schedule id: %s", "Deleted %s logs for report schedule id: %s",
str(row_count), str(row_count),
str(report_schedule.id), str(report_schedule.id),
) )
except DAODeleteFailedError as ex: except SQLAlchemyError as ex:
prune_errors.append(str(ex)) prune_errors.append(str(ex))
if prune_errors: if prune_errors:
raise ReportSchedulePruneLogError(";".join(prune_errors)) raise ReportSchedulePruneLogError(";".join(prune_errors))

View File

@ -15,6 +15,7 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import logging import logging
from functools import partial
from typing import Any, Optional from typing import Any, Optional
from flask_appbuilder.models.sqla import Model from flask_appbuilder.models.sqla import Model
@ -32,11 +33,11 @@ from superset.commands.report.exceptions import (
ReportScheduleUpdateFailedError, ReportScheduleUpdateFailedError,
) )
from superset.daos.database import DatabaseDAO from superset.daos.database import DatabaseDAO
from superset.daos.exceptions import DAOUpdateFailedError
from superset.daos.report import ReportScheduleDAO from superset.daos.report import ReportScheduleDAO
from superset.exceptions import SupersetSecurityException from superset.exceptions import SupersetSecurityException
from superset.reports.models import ReportSchedule, ReportScheduleType, ReportState from superset.reports.models import ReportSchedule, ReportScheduleType, ReportState
from superset.utils import json from superset.utils import json
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -47,16 +48,10 @@ class UpdateReportScheduleCommand(UpdateMixin, BaseReportScheduleCommand):
self._properties = data.copy() self._properties = data.copy()
self._model: Optional[ReportSchedule] = None self._model: Optional[ReportSchedule] = None
@transaction(on_error=partial(on_error, reraise=ReportScheduleUpdateFailedError))
def run(self) -> Model: def run(self) -> Model:
self.validate() self.validate()
assert self._model return ReportScheduleDAO.update(self._model, self._properties)
try:
report_schedule = ReportScheduleDAO.update(self._model, self._properties)
except DAOUpdateFailedError as ex:
logger.exception(ex.exception)
raise ReportScheduleUpdateFailedError() from ex
return report_schedule
def validate(self) -> None: def validate(self) -> None:
""" """

View File

@ -23,9 +23,9 @@ from superset.commands.base import BaseCommand
from superset.commands.exceptions import DatasourceNotFoundValidationError from superset.commands.exceptions import DatasourceNotFoundValidationError
from superset.commands.utils import populate_roles from superset.commands.utils import populate_roles
from superset.connectors.sqla.models import SqlaTable from superset.connectors.sqla.models import SqlaTable
from superset.daos.exceptions import DAOCreateFailedError
from superset.daos.security import RLSDAO from superset.daos.security import RLSDAO
from superset.extensions import db from superset.extensions import db
from superset.utils.decorators import transaction
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -36,13 +36,10 @@ class CreateRLSRuleCommand(BaseCommand):
self._tables = self._properties.get("tables", []) self._tables = self._properties.get("tables", [])
self._roles = self._properties.get("roles", []) self._roles = self._properties.get("roles", [])
@transaction()
def run(self) -> Any: def run(self) -> Any:
self.validate() self.validate()
try: return RLSDAO.create(attributes=self._properties)
return RLSDAO.create(attributes=self._properties)
except DAOCreateFailedError as ex:
logger.exception(ex.exception)
raise
def validate(self) -> None: def validate(self) -> None:
roles = populate_roles(self._roles) roles = populate_roles(self._roles)

View File

@ -16,15 +16,16 @@
# under the License. # under the License.
import logging import logging
from functools import partial
from superset.commands.base import BaseCommand from superset.commands.base import BaseCommand
from superset.commands.security.exceptions import ( from superset.commands.security.exceptions import (
RLSRuleNotFoundError, RLSRuleNotFoundError,
RuleDeleteFailedError, RuleDeleteFailedError,
) )
from superset.daos.exceptions import DAODeleteFailedError
from superset.daos.security import RLSDAO from superset.daos.security import RLSDAO
from superset.reports.models import ReportSchedule from superset.reports.models import ReportSchedule
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -34,13 +35,10 @@ class DeleteRLSRuleCommand(BaseCommand):
self._model_ids = model_ids self._model_ids = model_ids
self._models: list[ReportSchedule] = [] self._models: list[ReportSchedule] = []
@transaction(on_error=partial(on_error, reraise=RuleDeleteFailedError))
def run(self) -> None: def run(self) -> None:
self.validate() self.validate()
try: RLSDAO.delete(self._models)
RLSDAO.delete(self._models)
except DAODeleteFailedError as ex:
logger.exception(ex.exception)
raise RuleDeleteFailedError() from ex
def validate(self) -> None: def validate(self) -> None:
# Validate/populate model exists # Validate/populate model exists

View File

@ -24,9 +24,9 @@ from superset.commands.exceptions import DatasourceNotFoundValidationError
from superset.commands.security.exceptions import RLSRuleNotFoundError from superset.commands.security.exceptions import RLSRuleNotFoundError
from superset.commands.utils import populate_roles from superset.commands.utils import populate_roles
from superset.connectors.sqla.models import RowLevelSecurityFilter, SqlaTable from superset.connectors.sqla.models import RowLevelSecurityFilter, SqlaTable
from superset.daos.exceptions import DAOUpdateFailedError
from superset.daos.security import RLSDAO from superset.daos.security import RLSDAO
from superset.extensions import db from superset.extensions import db
from superset.utils.decorators import transaction
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -39,17 +39,11 @@ class UpdateRLSRuleCommand(BaseCommand):
self._roles = self._properties.get("roles", []) self._roles = self._properties.get("roles", [])
self._model: Optional[RowLevelSecurityFilter] = None self._model: Optional[RowLevelSecurityFilter] = None
@transaction()
def run(self) -> Any: def run(self) -> Any:
self.validate() self.validate()
assert self._model assert self._model
return RLSDAO.update(self._model, self._properties)
try:
rule = RLSDAO.update(self._model, self._properties)
except DAOUpdateFailedError as ex:
logger.exception(ex.exception)
raise
return rule
def validate(self) -> None: def validate(self) -> None:
self._model = RLSDAO.find_by_id(int(self._model_id)) self._model = RLSDAO.find_by_id(int(self._model_id))

View File

@ -22,10 +22,11 @@ import logging
from typing import Any, TYPE_CHECKING from typing import Any, TYPE_CHECKING
from flask_babel import gettext as __ from flask_babel import gettext as __
from sqlalchemy.exc import SQLAlchemyError
from superset import db
from superset.commands.base import BaseCommand from superset.commands.base import BaseCommand
from superset.common.db_query_status import QueryStatus from superset.common.db_query_status import QueryStatus
from superset.daos.exceptions import DAOCreateFailedError
from superset.errors import SupersetErrorType from superset.errors import SupersetErrorType
from superset.exceptions import ( from superset.exceptions import (
SupersetErrorException, SupersetErrorException,
@ -41,6 +42,7 @@ from superset.sqllab.exceptions import (
) )
from superset.sqllab.execution_context_convertor import ExecutionContextConvertor from superset.sqllab.execution_context_convertor import ExecutionContextConvertor
from superset.sqllab.limiting_factor import LimitingFactor from superset.sqllab.limiting_factor import LimitingFactor
from superset.utils.decorators import transaction
if TYPE_CHECKING: if TYPE_CHECKING:
from superset.daos.database import DatabaseDAO from superset.daos.database import DatabaseDAO
@ -90,6 +92,7 @@ class ExecuteSqlCommand(BaseCommand):
def validate(self) -> None: def validate(self) -> None:
pass pass
@transaction()
def run( # pylint: disable=too-many-statements,useless-suppression def run( # pylint: disable=too-many-statements,useless-suppression
self, self,
) -> CommandResult: ) -> CommandResult:
@ -178,9 +181,22 @@ class ExecuteSqlCommand(BaseCommand):
) )
def _save_new_query(self, query: Query) -> None: def _save_new_query(self, query: Query) -> None:
"""
Saves the new SQL Lab query.
Committing within a transaction violates the "unit of work" construct, but is
necessary for async querying. The Celery task is defined within the confines
of another command and needs to read a previously committed state given the
`READ COMMITTED` isolation level.
To mitigate said issue, ideally there would be a command to prepare said query
and another to execute it, either in a sync or async manner.
:param query: The SQL Lab query
"""
try: try:
self._query_dao.create(query) self._query_dao.create(query)
except DAOCreateFailedError as ex: except SQLAlchemyError as ex:
raise SqlLabException( raise SqlLabException(
self._execution_context, self._execution_context,
SupersetErrorType.GENERIC_DB_ENGINE_ERROR, SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
@ -189,6 +205,8 @@ class ExecuteSqlCommand(BaseCommand):
"Please contact an administrator for further assistance or try again.", "Please contact an administrator for further assistance or try again.",
) from ex ) from ex
db.session.commit() # pylint: disable=consider-using-transaction
def _validate_access(self, query: Query) -> None: def _validate_access(self, query: Query) -> None:
try: try:
self._access_validator.validate(query) self._access_validator.validate(query)

View File

@ -15,16 +15,17 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import logging import logging
from functools import partial
from typing import Any from typing import Any
from superset import db, security_manager from superset import security_manager
from superset.commands.base import BaseCommand, CreateMixin from superset.commands.base import BaseCommand, CreateMixin
from superset.commands.tag.exceptions import TagCreateFailedError, TagInvalidError from superset.commands.tag.exceptions import TagCreateFailedError, TagInvalidError
from superset.commands.tag.utils import to_object_model, to_object_type from superset.commands.tag.utils import to_object_model, to_object_type
from superset.daos.exceptions import DAOCreateFailedError
from superset.daos.tag import TagDAO from superset.daos.tag import TagDAO
from superset.exceptions import SupersetSecurityException from superset.exceptions import SupersetSecurityException
from superset.tags.models import ObjectType, TagType from superset.tags.models import ObjectType, TagType
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -35,20 +36,18 @@ class CreateCustomTagCommand(CreateMixin, BaseCommand):
self._object_id = object_id self._object_id = object_id
self._tags = tags self._tags = tags
@transaction(on_error=partial(on_error, reraise=TagCreateFailedError))
def run(self) -> None: def run(self) -> None:
self.validate() self.validate()
try: object_type = to_object_type(self._object_type)
object_type = to_object_type(self._object_type) if object_type is None:
if object_type is None: raise TagCreateFailedError(f"invalid object type {self._object_type}")
raise TagCreateFailedError(f"invalid object type {self._object_type}")
TagDAO.create_custom_tagged_objects( TagDAO.create_custom_tagged_objects(
object_type=object_type, object_type=object_type,
object_id=self._object_id, object_id=self._object_id,
tag_names=self._tags, tag_names=self._tags,
) )
except DAOCreateFailedError as ex:
logger.exception(ex.exception)
raise TagCreateFailedError() from ex
def validate(self) -> None: def validate(self) -> None:
exceptions = [] exceptions = []
@ -71,27 +70,20 @@ class CreateCustomTagWithRelationshipsCommand(CreateMixin, BaseCommand):
self._bulk_create = bulk_create self._bulk_create = bulk_create
self._skipped_tagged_objects: set[tuple[str, int]] = set() self._skipped_tagged_objects: set[tuple[str, int]] = set()
@transaction(on_error=partial(on_error, reraise=TagCreateFailedError))
def run(self) -> tuple[set[tuple[str, int]], set[tuple[str, int]]]: def run(self) -> tuple[set[tuple[str, int]], set[tuple[str, int]]]:
self.validate() self.validate()
try: tag_name = self._properties["name"]
tag_name = self._properties["name"] tag = TagDAO.get_by_name(tag_name.strip(), TagType.custom)
tag = TagDAO.get_by_name(tag_name.strip(), TagType.custom) TagDAO.create_tag_relationship(
TagDAO.create_tag_relationship( objects_to_tag=self._properties.get("objects_to_tag", []),
objects_to_tag=self._properties.get("objects_to_tag", []), tag=tag,
tag=tag, bulk_create=self._bulk_create,
bulk_create=self._bulk_create, )
)
tag.description = self._properties.get("description", "") tag.description = self._properties.get("description", "")
return set(self._properties["objects_to_tag"]), self._skipped_tagged_objects
db.session.commit()
return set(self._properties["objects_to_tag"]), self._skipped_tagged_objects
except DAOCreateFailedError as ex:
logger.exception(ex.exception)
raise TagCreateFailedError() from ex
def validate(self) -> None: def validate(self) -> None:
exceptions = [] exceptions = []

View File

@ -15,6 +15,7 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import logging import logging
from functools import partial
from superset.commands.base import BaseCommand from superset.commands.base import BaseCommand
from superset.commands.tag.exceptions import ( from superset.commands.tag.exceptions import (
@ -25,9 +26,9 @@ from superset.commands.tag.exceptions import (
TagNotFoundError, TagNotFoundError,
) )
from superset.commands.tag.utils import to_object_type from superset.commands.tag.utils import to_object_type
from superset.daos.exceptions import DAODeleteFailedError
from superset.daos.tag import TagDAO from superset.daos.tag import TagDAO
from superset.tags.models import ObjectType from superset.tags.models import ObjectType
from superset.utils.decorators import on_error, transaction
from superset.views.base import DeleteMixin from superset.views.base import DeleteMixin
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -39,18 +40,15 @@ class DeleteTaggedObjectCommand(DeleteMixin, BaseCommand):
self._object_id = object_id self._object_id = object_id
self._tag = tag self._tag = tag
@transaction(on_error=partial(on_error, reraise=TaggedObjectDeleteFailedError))
def run(self) -> None: def run(self) -> None:
self.validate() self.validate()
try: object_type = to_object_type(self._object_type)
object_type = to_object_type(self._object_type) if object_type is None:
if object_type is None: raise TaggedObjectDeleteFailedError(
raise TaggedObjectDeleteFailedError( f"invalid object type {self._object_type}"
f"invalid object type {self._object_type}" )
) TagDAO.delete_tagged_object(object_type, self._object_id, self._tag)
TagDAO.delete_tagged_object(object_type, self._object_id, self._tag)
except DAODeleteFailedError as ex:
logger.exception(ex.exception)
raise TaggedObjectDeleteFailedError() from ex
def validate(self) -> None: def validate(self) -> None:
exceptions = [] exceptions = []
@ -92,13 +90,10 @@ class DeleteTagsCommand(DeleteMixin, BaseCommand):
def __init__(self, tags: list[str]): def __init__(self, tags: list[str]):
self._tags = tags self._tags = tags
@transaction(on_error=partial(on_error, reraise=TagDeleteFailedError))
def run(self) -> None: def run(self) -> None:
self.validate() self.validate()
try: TagDAO.delete_tags(self._tags)
TagDAO.delete_tags(self._tags)
except DAODeleteFailedError as ex:
logger.exception(ex.exception)
raise TagDeleteFailedError() from ex
def validate(self) -> None: def validate(self) -> None:
exceptions = [] exceptions = []

View File

@ -25,6 +25,7 @@ from superset.commands.tag.exceptions import TagInvalidError, TagNotFoundError
from superset.commands.tag.utils import to_object_type from superset.commands.tag.utils import to_object_type
from superset.daos.tag import TagDAO from superset.daos.tag import TagDAO
from superset.tags.models import Tag from superset.tags.models import Tag
from superset.utils.decorators import transaction
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -35,18 +36,17 @@ class UpdateTagCommand(UpdateMixin, BaseCommand):
self._properties = data.copy() self._properties = data.copy()
self._model: Optional[Tag] = None self._model: Optional[Tag] = None
@transaction()
def run(self) -> Model: def run(self) -> Model:
self.validate() self.validate()
if self._model: assert self._model
self._model.name = self._properties["name"] self._model.name = self._properties["name"]
TagDAO.create_tag_relationship( TagDAO.create_tag_relationship(
objects_to_tag=self._properties.get("objects_to_tag", []), objects_to_tag=self._properties.get("objects_to_tag", []),
tag=self._model, tag=self._model,
) )
self._model.description = self._properties.get("description") self._model.description = self._properties.get("description")
db.session.add(self._model)
db.session.add(self._model)
db.session.commit()
return self._model return self._model

View File

@ -16,12 +16,12 @@
# under the License. # under the License.
import logging import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from functools import partial
from sqlalchemy.exc import SQLAlchemyError
from superset.commands.base import BaseCommand from superset.commands.base import BaseCommand
from superset.commands.temporary_cache.exceptions import TemporaryCacheCreateFailedError from superset.commands.temporary_cache.exceptions import TemporaryCacheCreateFailedError
from superset.commands.temporary_cache.parameters import CommandParameters from superset.commands.temporary_cache.parameters import CommandParameters
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -30,12 +30,9 @@ class CreateTemporaryCacheCommand(BaseCommand, ABC):
def __init__(self, cmd_params: CommandParameters): def __init__(self, cmd_params: CommandParameters):
self._cmd_params = cmd_params self._cmd_params = cmd_params
@transaction(on_error=partial(on_error, reraise=TemporaryCacheCreateFailedError))
def run(self) -> str: def run(self) -> str:
try: return self.create(self._cmd_params)
return self.create(self._cmd_params)
except SQLAlchemyError as ex:
logger.exception("Error running create command")
raise TemporaryCacheCreateFailedError() from ex
def validate(self) -> None: def validate(self) -> None:
pass pass

View File

@ -16,12 +16,12 @@
# under the License. # under the License.
import logging import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from functools import partial
from sqlalchemy.exc import SQLAlchemyError
from superset.commands.base import BaseCommand from superset.commands.base import BaseCommand
from superset.commands.temporary_cache.exceptions import TemporaryCacheDeleteFailedError from superset.commands.temporary_cache.exceptions import TemporaryCacheDeleteFailedError
from superset.commands.temporary_cache.parameters import CommandParameters from superset.commands.temporary_cache.parameters import CommandParameters
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -30,12 +30,9 @@ class DeleteTemporaryCacheCommand(BaseCommand, ABC):
def __init__(self, cmd_params: CommandParameters): def __init__(self, cmd_params: CommandParameters):
self._cmd_params = cmd_params self._cmd_params = cmd_params
@transaction(on_error=partial(on_error, reraise=TemporaryCacheDeleteFailedError))
def run(self) -> bool: def run(self) -> bool:
try: return self.delete(self._cmd_params)
return self.delete(self._cmd_params)
except SQLAlchemyError as ex:
logger.exception("Error running delete command")
raise TemporaryCacheDeleteFailedError() from ex
def validate(self) -> None: def validate(self) -> None:
pass pass

View File

@ -16,13 +16,13 @@
# under the License. # under the License.
import logging import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from functools import partial
from typing import Optional from typing import Optional
from sqlalchemy.exc import SQLAlchemyError
from superset.commands.base import BaseCommand from superset.commands.base import BaseCommand
from superset.commands.temporary_cache.exceptions import TemporaryCacheUpdateFailedError from superset.commands.temporary_cache.exceptions import TemporaryCacheUpdateFailedError
from superset.commands.temporary_cache.parameters import CommandParameters from superset.commands.temporary_cache.parameters import CommandParameters
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -34,12 +34,9 @@ class UpdateTemporaryCacheCommand(BaseCommand, ABC):
): ):
self._parameters = cmd_params self._parameters = cmd_params
@transaction(on_error=partial(on_error, reraise=TemporaryCacheUpdateFailedError))
def run(self) -> Optional[str]: def run(self) -> Optional[str]:
try: return self.update(self._parameters)
return self.update(self._parameters)
except SQLAlchemyError as ex:
logger.exception("Error running update command")
raise TemporaryCacheUpdateFailedError() from ex
def validate(self) -> None: def validate(self) -> None:
pass pass

View File

@ -1768,11 +1768,10 @@ class SqlaTable(
) )
) )
def fetch_metadata(self, commit: bool = True) -> MetadataResult: def fetch_metadata(self) -> MetadataResult:
""" """
Fetches the metadata for the table and merges it in Fetches the metadata for the table and merges it in
:param commit: should the changes be committed or not.
:return: Tuple with lists of added, removed and modified column names. :return: Tuple with lists of added, removed and modified column names.
""" """
new_columns = self.external_metadata() new_columns = self.external_metadata()
@ -1850,8 +1849,6 @@ class SqlaTable(
config["SQLA_TABLE_MUTATOR"](self) config["SQLA_TABLE_MUTATOR"](self)
db.session.merge(self) db.session.merge(self)
if commit:
db.session.commit()
return results return results
@classmethod @classmethod

View File

@ -21,13 +21,8 @@ from typing import Any, Generic, get_args, TypeVar
from flask_appbuilder.models.filters import BaseFilter from flask_appbuilder.models.filters import BaseFilter
from flask_appbuilder.models.sqla import Model from flask_appbuilder.models.sqla import Model
from flask_appbuilder.models.sqla.interface import SQLAInterface from flask_appbuilder.models.sqla.interface import SQLAInterface
from sqlalchemy.exc import SQLAlchemyError, StatementError from sqlalchemy.exc import StatementError
from superset.daos.exceptions import (
DAOCreateFailedError,
DAODeleteFailedError,
DAOUpdateFailedError,
)
from superset.extensions import db from superset.extensions import db
T = TypeVar("T", bound=Model) T = TypeVar("T", bound=Model)
@ -127,15 +122,12 @@ class BaseDAO(Generic[T]):
cls, cls,
item: T | None = None, item: T | None = None,
attributes: dict[str, Any] | None = None, attributes: dict[str, Any] | None = None,
commit: bool = True,
) -> T: ) -> T:
""" """
Create an object from the specified item and/or attributes. Create an object from the specified item and/or attributes.
:param item: The object to create :param item: The object to create
:param attributes: The attributes associated with the object to create :param attributes: The attributes associated with the object to create
:param commit: Whether to commit the transaction
:raises DAOCreateFailedError: If the creation failed
""" """
if not item: if not item:
@ -145,15 +137,7 @@ class BaseDAO(Generic[T]):
for key, value in attributes.items(): for key, value in attributes.items():
setattr(item, key, value) setattr(item, key, value)
try: db.session.add(item)
db.session.add(item)
if commit:
db.session.commit()
except SQLAlchemyError as ex: # pragma: no cover
db.session.rollback()
raise DAOCreateFailedError(exception=ex) from ex
return item # type: ignore return item # type: ignore
@classmethod @classmethod
@ -161,15 +145,12 @@ class BaseDAO(Generic[T]):
cls, cls,
item: T | None = None, item: T | None = None,
attributes: dict[str, Any] | None = None, attributes: dict[str, Any] | None = None,
commit: bool = True,
) -> T: ) -> T:
""" """
Update an object from the specified item and/or attributes. Update an object from the specified item and/or attributes.
:param item: The object to update :param item: The object to update
:param attributes: The attributes associated with the object to update :param attributes: The attributes associated with the object to update
:param commit: Whether to commit the transaction
:raises DAOUpdateFailedError: If the updating failed
""" """
if not item: if not item:
@ -179,19 +160,13 @@ class BaseDAO(Generic[T]):
for key, value in attributes.items(): for key, value in attributes.items():
setattr(item, key, value) setattr(item, key, value)
try: if item not in db.session:
db.session.merge(item) return db.session.merge(item)
if commit:
db.session.commit()
except SQLAlchemyError as ex: # pragma: no cover
db.session.rollback()
raise DAOUpdateFailedError(exception=ex) from ex
return item # type: ignore return item # type: ignore
@classmethod @classmethod
def delete(cls, items: list[T], commit: bool = True) -> None: def delete(cls, items: list[T]) -> None:
""" """
Delete the specified items including their associated relationships. Delete the specified items including their associated relationships.
@ -204,17 +179,8 @@ class BaseDAO(Generic[T]):
post-deletion logic. post-deletion logic.
:param items: The items to delete :param items: The items to delete
:param commit: Whether to commit the transaction
:raises DAODeleteFailedError: If the deletion failed
:see: https://docs.sqlalchemy.org/en/latest/orm/queryguide/dml.html :see: https://docs.sqlalchemy.org/en/latest/orm/queryguide/dml.html
""" """
try: for item in items:
for item in items: db.session.delete(item)
db.session.delete(item)
if commit:
db.session.commit()
except SQLAlchemyError as ex:
db.session.rollback()
raise DAODeleteFailedError(exception=ex) from ex

View File

@ -62,7 +62,6 @@ class ChartDAO(BaseDAO[Slice]):
dttm=datetime.now(), dttm=datetime.now(),
) )
) )
db.session.commit()
@staticmethod @staticmethod
def remove_favorite(chart: Slice) -> None: def remove_favorite(chart: Slice) -> None:
@ -77,4 +76,3 @@ class ChartDAO(BaseDAO[Slice]):
) )
if fav: if fav:
db.session.delete(fav) db.session.delete(fav)
db.session.commit()

View File

@ -179,8 +179,7 @@ class DashboardDAO(BaseDAO[Dashboard]):
dashboard: Dashboard, dashboard: Dashboard,
data: dict[Any, Any], data: dict[Any, Any],
old_to_new_slice_ids: dict[int, int] | None = None, old_to_new_slice_ids: dict[int, int] | None = None,
commit: bool = False, ) -> None:
) -> Dashboard:
new_filter_scopes = {} new_filter_scopes = {}
md = dashboard.params_dict md = dashboard.params_dict
@ -265,10 +264,6 @@ class DashboardDAO(BaseDAO[Dashboard]):
md["cross_filters_enabled"] = data.get("cross_filters_enabled", True) md["cross_filters_enabled"] = data.get("cross_filters_enabled", True)
dashboard.json_metadata = json.dumps(md) dashboard.json_metadata = json.dumps(md)
if commit:
db.session.commit()
return dashboard
@staticmethod @staticmethod
def favorited_ids(dashboards: list[Dashboard]) -> list[FavStar]: def favorited_ids(dashboards: list[Dashboard]) -> list[FavStar]:
ids = [dash.id for dash in dashboards] ids = [dash.id for dash in dashboards]
@ -321,7 +316,6 @@ class DashboardDAO(BaseDAO[Dashboard]):
dash.params = original_dash.params dash.params = original_dash.params
cls.set_dash_metadata(dash, metadata, old_to_new_slice_ids) cls.set_dash_metadata(dash, metadata, old_to_new_slice_ids)
db.session.add(dash) db.session.add(dash)
db.session.commit()
return dash return dash
@staticmethod @staticmethod
@ -336,7 +330,6 @@ class DashboardDAO(BaseDAO[Dashboard]):
dttm=datetime.now(), dttm=datetime.now(),
) )
) )
db.session.commit()
@staticmethod @staticmethod
def remove_favorite(dashboard: Dashboard) -> None: def remove_favorite(dashboard: Dashboard) -> None:
@ -351,7 +344,6 @@ class DashboardDAO(BaseDAO[Dashboard]):
) )
if fav: if fav:
db.session.delete(fav) db.session.delete(fav)
db.session.commit()
class EmbeddedDashboardDAO(BaseDAO[EmbeddedDashboard]): class EmbeddedDashboardDAO(BaseDAO[EmbeddedDashboard]):
@ -369,7 +361,6 @@ class EmbeddedDashboardDAO(BaseDAO[EmbeddedDashboard]):
) )
embedded.allow_domain_list = ",".join(allowed_domains) embedded.allow_domain_list = ",".join(allowed_domains)
dashboard.embedded = [embedded] dashboard.embedded = [embedded]
db.session.commit()
return embedded return embedded
@classmethod @classmethod
@ -377,7 +368,6 @@ class EmbeddedDashboardDAO(BaseDAO[EmbeddedDashboard]):
cls, cls,
item: EmbeddedDashboardDAO | None = None, item: EmbeddedDashboardDAO | None = None,
attributes: dict[str, Any] | None = None, attributes: dict[str, Any] | None = None,
commit: bool = True,
) -> Any: ) -> Any:
""" """
Use EmbeddedDashboardDAO.upsert() instead. Use EmbeddedDashboardDAO.upsert() instead.

View File

@ -42,7 +42,6 @@ class DatabaseDAO(BaseDAO[Database]):
cls, cls,
item: Database | None = None, item: Database | None = None,
attributes: dict[str, Any] | None = None, attributes: dict[str, Any] | None = None,
commit: bool = True,
) -> Database: ) -> Database:
""" """
Unmask ``encrypted_extra`` before updating. Unmask ``encrypted_extra`` before updating.
@ -60,7 +59,7 @@ class DatabaseDAO(BaseDAO[Database]):
attributes["encrypted_extra"], attributes["encrypted_extra"],
) )
return super().update(item, attributes, commit) return super().update(item, attributes)
@staticmethod @staticmethod
def validate_uniqueness(database_name: str) -> bool: def validate_uniqueness(database_name: str) -> bool:
@ -174,7 +173,6 @@ class SSHTunnelDAO(BaseDAO[SSHTunnel]):
cls, cls,
item: SSHTunnel | None = None, item: SSHTunnel | None = None,
attributes: dict[str, Any] | None = None, attributes: dict[str, Any] | None = None,
commit: bool = True,
) -> SSHTunnel: ) -> SSHTunnel:
""" """
Unmask ``password``, ``private_key`` and ``private_key_password`` before updating. Unmask ``password``, ``private_key`` and ``private_key_password`` before updating.
@ -190,7 +188,7 @@ class SSHTunnelDAO(BaseDAO[SSHTunnel]):
attributes.pop("id", None) attributes.pop("id", None)
attributes = unmask_password_info(attributes, item) attributes = unmask_password_info(attributes, item)
return super().update(item, attributes, commit) return super().update(item, attributes)
class DatabaseUserOAuth2TokensDAO(BaseDAO[DatabaseUserOAuth2Tokens]): class DatabaseUserOAuth2TokensDAO(BaseDAO[DatabaseUserOAuth2Tokens]):

View File

@ -25,7 +25,6 @@ from sqlalchemy.exc import SQLAlchemyError
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
from superset.daos.base import BaseDAO from superset.daos.base import BaseDAO
from superset.daos.exceptions import DAOUpdateFailedError
from superset.extensions import db from superset.extensions import db
from superset.models.core import Database from superset.models.core import Database
from superset.models.dashboard import Dashboard from superset.models.dashboard import Dashboard
@ -171,7 +170,6 @@ class DatasetDAO(BaseDAO[SqlaTable]):
cls, cls,
item: SqlaTable | None = None, item: SqlaTable | None = None,
attributes: dict[str, Any] | None = None, attributes: dict[str, Any] | None = None,
commit: bool = True,
) -> SqlaTable: ) -> SqlaTable:
""" """
Updates a Dataset model on the metadata DB Updates a Dataset model on the metadata DB
@ -182,21 +180,19 @@ class DatasetDAO(BaseDAO[SqlaTable]):
cls.update_columns( cls.update_columns(
item, item,
attributes.pop("columns"), attributes.pop("columns"),
commit=commit,
override_columns=bool(attributes.get("override_columns")), override_columns=bool(attributes.get("override_columns")),
) )
if "metrics" in attributes: if "metrics" in attributes:
cls.update_metrics(item, attributes.pop("metrics"), commit=commit) cls.update_metrics(item, attributes.pop("metrics"))
return super().update(item, attributes, commit=commit) return super().update(item, attributes)
@classmethod @classmethod
def update_columns( def update_columns(
cls, cls,
model: SqlaTable, model: SqlaTable,
property_columns: list[dict[str, Any]], property_columns: list[dict[str, Any]],
commit: bool = True,
override_columns: bool = False, override_columns: bool = False,
) -> None: ) -> None:
""" """
@ -217,7 +213,7 @@ class DatasetDAO(BaseDAO[SqlaTable]):
if not DatasetDAO.validate_python_date_format( if not DatasetDAO.validate_python_date_format(
column["python_date_format"] column["python_date_format"]
): ):
raise DAOUpdateFailedError( raise ValueError(
"python_date_format is an invalid date/timestamp format." "python_date_format is an invalid date/timestamp format."
) )
@ -266,15 +262,11 @@ class DatasetDAO(BaseDAO[SqlaTable]):
) )
).delete(synchronize_session="fetch") ).delete(synchronize_session="fetch")
if commit:
db.session.commit()
@classmethod @classmethod
def update_metrics( def update_metrics(
cls, cls,
model: SqlaTable, model: SqlaTable,
property_metrics: list[dict[str, Any]], property_metrics: list[dict[str, Any]],
commit: bool = True,
) -> None: ) -> None:
""" """
Creates/updates and/or deletes a list of metrics, based on a Creates/updates and/or deletes a list of metrics, based on a
@ -317,9 +309,6 @@ class DatasetDAO(BaseDAO[SqlaTable]):
) )
).delete(synchronize_session="fetch") ).delete(synchronize_session="fetch")
if commit:
db.session.commit()
@classmethod @classmethod
def find_dataset_column(cls, dataset_id: int, column_id: int) -> TableColumn | None: def find_dataset_column(cls, dataset_id: int, column_id: int) -> TableColumn | None:
# We want to apply base dataset filters # We want to apply base dataset filters

View File

@ -23,30 +23,6 @@ class DAOException(SupersetException):
""" """
class DAOCreateFailedError(DAOException):
"""
DAO Create failed
"""
message = "Create failed"
class DAOUpdateFailedError(DAOException):
"""
DAO Update failed
"""
message = "Update failed"
class DAODeleteFailedError(DAOException):
"""
DAO Delete failed
"""
message = "Delete failed"
class DatasourceTypeNotSupportedError(DAOException): class DatasourceTypeNotSupportedError(DAOException):
""" """
DAO datasource query source type is not supported DAO datasource query source type is not supported

View File

@ -53,7 +53,6 @@ class QueryDAO(BaseDAO[Query]):
for saved_query in related_saved_queries: for saved_query in related_saved_queries:
saved_query.rows = query.rows saved_query.rows = query.rows
saved_query.last_run = datetime.now() saved_query.last_run = datetime.now()
db.session.commit()
@staticmethod @staticmethod
def save_metadata(query: Query, payload: dict[str, Any]) -> None: def save_metadata(query: Query, payload: dict[str, Any]) -> None:
@ -97,7 +96,6 @@ class QueryDAO(BaseDAO[Query]):
query.status = QueryStatus.STOPPED query.status = QueryStatus.STOPPED
query.end_time = now_as_float() query.end_time = now_as_float()
db.session.commit()
class SavedQueryDAO(BaseDAO[SavedQuery]): class SavedQueryDAO(BaseDAO[SavedQuery]):

View File

@ -20,10 +20,7 @@ import logging
from datetime import datetime from datetime import datetime
from typing import Any from typing import Any
from sqlalchemy.exc import SQLAlchemyError
from superset.daos.base import BaseDAO from superset.daos.base import BaseDAO
from superset.daos.exceptions import DAODeleteFailedError
from superset.extensions import db from superset.extensions import db
from superset.reports.filters import ReportScheduleFilter from superset.reports.filters import ReportScheduleFilter
from superset.reports.models import ( from superset.reports.models import (
@ -137,15 +134,12 @@ class ReportScheduleDAO(BaseDAO[ReportSchedule]):
cls, cls,
item: ReportSchedule | None = None, item: ReportSchedule | None = None,
attributes: dict[str, Any] | None = None, attributes: dict[str, Any] | None = None,
commit: bool = True,
) -> ReportSchedule: ) -> ReportSchedule:
""" """
Create a report schedule with nested recipients. Create a report schedule with nested recipients.
:param item: The object to create :param item: The object to create
:param attributes: The attributes associated with the object to create :param attributes: The attributes associated with the object to create
:param commit: Whether to commit the transaction
:raises: DAOCreateFailedError: If the creation failed
""" """
# TODO(john-bodley): Determine why we need special handling for recipients. # TODO(john-bodley): Determine why we need special handling for recipients.
@ -165,22 +159,19 @@ class ReportScheduleDAO(BaseDAO[ReportSchedule]):
for recipient in recipients for recipient in recipients
] ]
return super().create(item, attributes, commit) return super().create(item, attributes)
@classmethod @classmethod
def update( def update(
cls, cls,
item: ReportSchedule | None = None, item: ReportSchedule | None = None,
attributes: dict[str, Any] | None = None, attributes: dict[str, Any] | None = None,
commit: bool = True,
) -> ReportSchedule: ) -> ReportSchedule:
""" """
Update a report schedule with nested recipients. Update a report schedule with nested recipients.
:param item: The object to update :param item: The object to update
:param attributes: The attributes associated with the object to update :param attributes: The attributes associated with the object to update
:param commit: Whether to commit the transaction
:raises: DAOUpdateFailedError: If the update failed
""" """
# TODO(john-bodley): Determine why we need special handling for recipients. # TODO(john-bodley): Determine why we need special handling for recipients.
@ -200,7 +191,7 @@ class ReportScheduleDAO(BaseDAO[ReportSchedule]):
for recipient in recipients for recipient in recipients
] ]
return super().update(item, attributes, commit) return super().update(item, attributes)
@staticmethod @staticmethod
def find_active() -> list[ReportSchedule]: def find_active() -> list[ReportSchedule]:
@ -283,23 +274,12 @@ class ReportScheduleDAO(BaseDAO[ReportSchedule]):
return last_error_email_log if not report_from_last_email else None return last_error_email_log if not report_from_last_email else None
@staticmethod @staticmethod
def bulk_delete_logs( def bulk_delete_logs(model: ReportSchedule, from_date: datetime) -> int | None:
model: ReportSchedule, return (
from_date: datetime, db.session.query(ReportExecutionLog)
commit: bool = True, .filter(
) -> int | None: ReportExecutionLog.report_schedule == model,
try: ReportExecutionLog.end_dttm < from_date,
row_count = (
db.session.query(ReportExecutionLog)
.filter(
ReportExecutionLog.report_schedule == model,
ReportExecutionLog.end_dttm < from_date,
)
.delete(synchronize_session="fetch")
) )
if commit: .delete(synchronize_session="fetch")
db.session.commit() )
return row_count
except SQLAlchemyError as ex:
db.session.rollback()
raise DAODeleteFailedError(str(ex)) from ex

View File

@ -19,12 +19,11 @@ from operator import and_
from typing import Any, Optional from typing import Any, Optional
from flask import g from flask import g
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import NoResultFound
from superset.commands.tag.exceptions import TagNotFoundError from superset.commands.tag.exceptions import TagNotFoundError
from superset.commands.tag.utils import to_object_type from superset.commands.tag.utils import to_object_type
from superset.daos.base import BaseDAO from superset.daos.base import BaseDAO
from superset.daos.exceptions import DAODeleteFailedError
from superset.exceptions import MissingUserContextException from superset.exceptions import MissingUserContextException
from superset.extensions import db from superset.extensions import db
from superset.models.dashboard import Dashboard from superset.models.dashboard import Dashboard
@ -75,7 +74,6 @@ class TagDAO(BaseDAO[Tag]):
) )
db.session.add_all(tagged_objects) db.session.add_all(tagged_objects)
db.session.commit()
@staticmethod @staticmethod
def delete_tagged_object( def delete_tagged_object(
@ -86,9 +84,7 @@ class TagDAO(BaseDAO[Tag]):
""" """
tag = TagDAO.find_by_name(tag_name.strip()) tag = TagDAO.find_by_name(tag_name.strip())
if not tag: if not tag:
raise DAODeleteFailedError( raise NoResultFound(message=f"Tag with name {tag_name} does not exist.")
message=f"Tag with name {tag_name} does not exist."
)
tagged_object = db.session.query(TaggedObject).filter( tagged_object = db.session.query(TaggedObject).filter(
TaggedObject.tag_id == tag.id, TaggedObject.tag_id == tag.id,
@ -96,17 +92,13 @@ class TagDAO(BaseDAO[Tag]):
TaggedObject.object_id == object_id, TaggedObject.object_id == object_id,
) )
if not tagged_object: if not tagged_object:
raise DAODeleteFailedError( raise NoResultFound(
message=f'Tagged object with object_id: {object_id} \ message=f'Tagged object with object_id: {object_id} \
object_type: {object_type} \ object_type: {object_type} \
and tag name: "{tag_name}" could not be found' and tag name: "{tag_name}" could not be found'
) )
try:
db.session.delete(tagged_object.one()) db.session.delete(tagged_object.one())
db.session.commit()
except SQLAlchemyError as ex: # pragma: no cover
db.session.rollback()
raise DAODeleteFailedError(exception=ex) from ex
@staticmethod @staticmethod
def delete_tags(tag_names: list[str]) -> None: def delete_tags(tag_names: list[str]) -> None:
@ -117,18 +109,12 @@ class TagDAO(BaseDAO[Tag]):
for name in tag_names: for name in tag_names:
tag_name = name.strip() tag_name = name.strip()
if not TagDAO.find_by_name(tag_name): if not TagDAO.find_by_name(tag_name):
raise DAODeleteFailedError( raise NoResultFound(message=f"Tag with name {tag_name} does not exist.")
message=f"Tag with name {tag_name} does not exist."
)
tags_to_delete.append(tag_name) tags_to_delete.append(tag_name)
tag_objects = db.session.query(Tag).filter(Tag.name.in_(tags_to_delete)) tag_objects = db.session.query(Tag).filter(Tag.name.in_(tags_to_delete))
for tag in tag_objects: for tag in tag_objects:
try: db.session.delete(tag)
db.session.delete(tag)
db.session.commit()
except SQLAlchemyError as ex: # pragma: no cover
db.session.rollback()
raise DAODeleteFailedError(exception=ex) from ex
@staticmethod @staticmethod
def get_by_name(name: str, type_: TagType = TagType.custom) -> Tag: def get_by_name(name: str, type_: TagType = TagType.custom) -> Tag:
@ -283,21 +269,10 @@ class TagDAO(BaseDAO[Tag]):
) -> None: ) -> None:
""" """
Marks a specific tag as a favorite for the current user. Marks a specific tag as a favorite for the current user.
This function will find the tag by the provided id,
create a new UserFavoriteTag object that represents :param tag_id: The id of the tag that is to be marked as favorite
the user's preference, add that object to the database
session, and commit the session. It uses the currently
authenticated user from the global 'g' object.
Args:
tag_id: The id of the tag that is to be marked as
favorite.
Raises:
Any exceptions raised by the find_by_id function,
the UserFavoriteTag constructor, or the database session's
add and commit methods will propagate up to the caller.
Returns:
None.
""" """
tag = TagDAO.find_by_id(tag_id) tag = TagDAO.find_by_id(tag_id)
user = g.user user = g.user
@ -307,26 +282,13 @@ class TagDAO(BaseDAO[Tag]):
raise TagNotFoundError() raise TagNotFoundError()
tag.users_favorited.append(user) tag.users_favorited.append(user)
db.session.commit()
@staticmethod @staticmethod
def remove_user_favorite_tag(tag_id: int) -> None: def remove_user_favorite_tag(tag_id: int) -> None:
""" """
Removes a tag from the current user's favorite tags. Removes a tag from the current user's favorite tags.
This function will find the tag by the provided id and remove the tag :param tag_id: The id of the tag that is to be removed from the favorite tags
from the user's list of favorite tags. It uses the currently authenticated
user from the global 'g' object.
Args:
tag_id: The id of the tag that is to be removed from the favorite tags.
Raises:
Any exceptions raised by the find_by_id function, the database session's
commit method will propagate up to the caller.
Returns:
None.
""" """
tag = TagDAO.find_by_id(tag_id) tag = TagDAO.find_by_id(tag_id)
user = g.user user = g.user
@ -338,9 +300,6 @@ class TagDAO(BaseDAO[Tag]):
tag.users_favorited.remove(user) tag.users_favorited.remove(user)
# Commit to save the changes
db.session.commit()
@staticmethod @staticmethod
def favorited_ids(tags: list[Tag]) -> list[int]: def favorited_ids(tags: list[Tag]) -> list[int]:
""" """
@ -424,5 +383,4 @@ class TagDAO(BaseDAO[Tag]):
object_id, object_id,
tag.name, tag.name,
) )
db.session.add_all(tagged_objects) db.session.add_all(tagged_objects)

View File

@ -40,4 +40,3 @@ class UserDAO(BaseDAO[User]):
attrs = UserAttribute(avatar_url=url, user_id=user.id) attrs = UserAttribute(avatar_url=url, user_id=user.id)
user.extra_attributes = [attrs] user.extra_attributes = [attrs]
db.session.add(attrs) db.session.add(attrs)
db.session.commit()

View File

@ -32,7 +32,7 @@ from marshmallow import ValidationError
from werkzeug.wrappers import Response as WerkzeugResponse from werkzeug.wrappers import Response as WerkzeugResponse
from werkzeug.wsgi import FileWrapper from werkzeug.wsgi import FileWrapper
from superset import is_feature_enabled, thumbnail_cache from superset import db, is_feature_enabled, thumbnail_cache
from superset.charts.schemas import ChartEntityResponseSchema from superset.charts.schemas import ChartEntityResponseSchema
from superset.commands.dashboard.create import CreateDashboardCommand from superset.commands.dashboard.create import CreateDashboardCommand
from superset.commands.dashboard.delete import DeleteDashboardCommand from superset.commands.dashboard.delete import DeleteDashboardCommand
@ -1314,7 +1314,13 @@ class DashboardRestApi(BaseSupersetModelRestApi):
""" """
try: try:
body = self.embedded_config_schema.load(request.json) body = self.embedded_config_schema.load(request.json)
embedded = EmbeddedDashboardDAO.upsert(dashboard, body["allowed_domains"])
with db.session.begin_nested():
embedded = EmbeddedDashboardDAO.upsert(
dashboard,
body["allowed_domains"],
)
result = self.embedded_response_schema.dump(embedded) result = self.embedded_response_schema.dump(embedded)
return self.response(200, result=result) return self.response(200, result=result)
except ValidationError as error: except ValidationError as error:

View File

@ -1410,7 +1410,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
database_id=state["database_id"], database_id=state["database_id"],
) )
if existing: if existing:
DatabaseUserOAuth2TokensDAO.delete([existing], commit=True) DatabaseUserOAuth2TokensDAO.delete([existing])
# store tokens # store tokens
expiration = datetime.now() + timedelta(seconds=token_response["expires_in"]) expiration = datetime.now() + timedelta(seconds=token_response["expires_in"])
@ -1422,7 +1422,6 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
"access_token_expiration": expiration, "access_token_expiration": expiration,
"refresh_token": token_response.get("refresh_token"), "refresh_token": token_response.get("refresh_token"),
}, },
commit=True,
) )
# return blank page that closes itself # return blank page that closes itself

View File

@ -455,4 +455,4 @@ class GSheetsEngineSpec(ShillelaghEngineSpec):
catalog[table.table] = spreadsheet_url catalog[table.table] = spreadsheet_url
database.extra = json.dumps(extra) database.extra = json.dumps(extra)
db.session.add(database) db.session.add(database)
db.session.commit() db.session.commit() # pylint: disable=consider-using-transaction

View File

@ -408,7 +408,7 @@ class HiveEngineSpec(PrestoEngineSpec):
logger.info("Query %s: [%s] %s", str(query_id), str(job_id), l) logger.info("Query %s: [%s] %s", str(query_id), str(job_id), l)
last_log_line = len(log_lines) last_log_line = len(log_lines)
if needs_commit: if needs_commit:
db.session.commit() db.session.commit() # pylint: disable=consider-using-transaction
if sleep_interval := current_app.config.get("HIVE_POLL_INTERVAL"): if sleep_interval := current_app.config.get("HIVE_POLL_INTERVAL"):
logger.warning( logger.warning(
"HIVE_POLL_INTERVAL is deprecated and will be removed in 3.0. Please use DB_POLL_INTERVAL_SECONDS instead" "HIVE_POLL_INTERVAL is deprecated and will be removed in 3.0. Please use DB_POLL_INTERVAL_SECONDS instead"

View File

@ -151,7 +151,7 @@ class ImpalaEngineSpec(BaseEngineSpec):
needs_commit = True needs_commit = True
if needs_commit: if needs_commit:
db.session.commit() db.session.commit() # pylint: disable=consider-using-transaction
sleep_interval = current_app.config["DB_POLL_INTERVAL_SECONDS"].get( sleep_interval = current_app.config["DB_POLL_INTERVAL_SECONDS"].get(
cls.engine, 5 cls.engine, 5
) )

View File

@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
# pylint: disable=too-many-lines # pylint: disable=consider-using-transaction,too-many-lines
from __future__ import annotations from __future__ import annotations
import contextlib import contextlib

View File

@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
# pylint: disable=consider-using-transaction
from __future__ import annotations from __future__ import annotations
import contextlib import contextlib

View File

@ -65,5 +65,4 @@ def load_bart_lines(only_metadata: bool = False, force: bool = False) -> None:
tbl.description = "BART lines" tbl.description = "BART lines"
tbl.database = database tbl.database = database
tbl.filter_select_enabled = True tbl.filter_select_enabled = True
db.session.commit()
tbl.fetch_metadata() tbl.fetch_metadata()

View File

@ -111,8 +111,6 @@ def load_birth_names(
_set_table_metadata(obj, database) _set_table_metadata(obj, database)
_add_table_metrics(obj) _add_table_metrics(obj)
db.session.commit()
slices, _ = create_slices(obj) slices, _ = create_slices(obj)
create_dashboard(slices) create_dashboard(slices)
@ -844,5 +842,4 @@ def create_dashboard(slices: list[Slice]) -> Dashboard:
dash.dashboard_title = "USA Births Names" dash.dashboard_title = "USA Births Names"
dash.position_json = json.dumps(pos, indent=4) dash.position_json = json.dumps(pos, indent=4)
dash.slug = "births" dash.slug = "births"
db.session.commit()
return dash return dash

View File

@ -88,7 +88,6 @@ def load_country_map_data(only_metadata: bool = False, force: bool = False) -> N
if not any(col.metric_name == "avg__2004" for col in obj.metrics): if not any(col.metric_name == "avg__2004" for col in obj.metrics):
col = str(column("2004").compile(db.engine)) col = str(column("2004").compile(db.engine))
obj.metrics.append(SqlMetric(metric_name="avg__2004", expression=f"AVG({col})")) obj.metrics.append(SqlMetric(metric_name="avg__2004", expression=f"AVG({col})"))
db.session.commit()
obj.fetch_metadata() obj.fetch_metadata()
tbl = obj tbl = obj

View File

@ -52,7 +52,6 @@ def load_css_templates() -> None:
""" """
) )
obj.css = css obj.css = css
db.session.commit()
obj = db.session.query(CssTemplate).filter_by(template_name="Courier Black").first() obj = db.session.query(CssTemplate).filter_by(template_name="Courier Black").first()
if not obj: if not obj:
@ -97,4 +96,3 @@ def load_css_templates() -> None:
""" """
) )
obj.css = css obj.css = css
db.session.commit()

View File

@ -541,4 +541,3 @@ def load_deck_dash() -> None: # pylint: disable=too-many-statements
dash.dashboard_title = title dash.dashboard_title = title
dash.slug = slug dash.slug = slug
dash.slices = slices dash.slices = slices
db.session.commit()

View File

@ -14,8 +14,6 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
"""Loads datasets, dashboards and slices in a new superset instance"""
import textwrap import textwrap
import pandas as pd import pandas as pd
@ -79,7 +77,6 @@ def load_energy(
SqlMetric(metric_name="sum__value", expression=f"SUM({col})") SqlMetric(metric_name="sum__value", expression=f"SUM({col})")
) )
db.session.commit()
tbl.fetch_metadata() tbl.fetch_metadata()
slc = Slice( slc = Slice(

View File

@ -66,6 +66,5 @@ def load_flights(only_metadata: bool = False, force: bool = False) -> None:
tbl.description = "Random set of flights in the US" tbl.description = "Random set of flights in the US"
tbl.database = database tbl.database = database
tbl.filter_select_enabled = True tbl.filter_select_enabled = True
db.session.commit()
tbl.fetch_metadata() tbl.fetch_metadata()
print("Done loading table!") print("Done loading table!")

View File

@ -14,8 +14,6 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
"""Loads datasets, dashboards and slices in a new superset instance"""
import os import os
from typing import Any from typing import Any
@ -62,7 +60,6 @@ def merge_slice(slc: Slice) -> None:
if o: if o:
db.session.delete(o) db.session.delete(o)
db.session.add(slc) db.session.add(slc)
db.session.commit()
def get_slice_json(defaults: dict[Any, Any], **kwargs: Any) -> str: def get_slice_json(defaults: dict[Any, Any], **kwargs: Any) -> str:

View File

@ -97,7 +97,6 @@ def load_long_lat_data(only_metadata: bool = False, force: bool = False) -> None
obj.main_dttm_col = "datetime" obj.main_dttm_col = "datetime"
obj.database = database obj.database = database
obj.filter_select_enabled = True obj.filter_select_enabled = True
db.session.commit()
obj.fetch_metadata() obj.fetch_metadata()
tbl = obj tbl = obj

View File

@ -140,4 +140,3 @@ def load_misc_dashboard() -> None:
dash.position_json = json.dumps(pos, indent=4) dash.position_json = json.dumps(pos, indent=4)
dash.slug = DASH_SLUG dash.slug = DASH_SLUG
dash.slices = slices dash.slices = slices
db.session.commit()

View File

@ -102,7 +102,6 @@ def load_multiformat_time_series( # pylint: disable=too-many-locals
col.python_date_format = dttm_and_expr[0] col.python_date_format = dttm_and_expr[0]
col.database_expression = dttm_and_expr[1] col.database_expression = dttm_and_expr[1]
col.is_dttm = True col.is_dttm = True
db.session.commit()
obj.fetch_metadata() obj.fetch_metadata()
tbl = obj tbl = obj

View File

@ -62,5 +62,4 @@ def load_paris_iris_geojson(only_metadata: bool = False, force: bool = False) ->
tbl.description = "Map of Paris" tbl.description = "Map of Paris"
tbl.database = database tbl.database = database
tbl.filter_select_enabled = True tbl.filter_select_enabled = True
db.session.commit()
tbl.fetch_metadata() tbl.fetch_metadata()

View File

@ -14,7 +14,6 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import pandas as pd import pandas as pd
from sqlalchemy import DateTime, inspect, String from sqlalchemy import DateTime, inspect, String
@ -72,7 +71,6 @@ def load_random_time_series_data(
obj.main_dttm_col = "ds" obj.main_dttm_col = "ds"
obj.database = database obj.database = database
obj.filter_select_enabled = True obj.filter_select_enabled = True
db.session.commit()
obj.fetch_metadata() obj.fetch_metadata()
tbl = obj tbl = obj

View File

@ -64,5 +64,4 @@ def load_sf_population_polygons(
tbl.description = "Population density of San Francisco" tbl.description = "Population density of San Francisco"
tbl.database = database tbl.database = database
tbl.filter_select_enabled = True tbl.filter_select_enabled = True
db.session.commit()
tbl.fetch_metadata() tbl.fetch_metadata()

View File

@ -14,9 +14,7 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
# pylint: disable=too-many-lines # pylint: disable=too-many-lines
import textwrap import textwrap
from sqlalchemy import inspect from sqlalchemy import inspect
@ -1274,4 +1272,3 @@ def load_supported_charts_dashboard() -> None:
dash.dashboard_title = "Supported Charts Dashboard" dash.dashboard_title = "Supported Charts Dashboard"
dash.position_json = json.dumps(pos, indent=2) dash.position_json = json.dumps(pos, indent=2)
dash.slug = DASH_SLUG dash.slug = DASH_SLUG
db.session.commit()

View File

@ -14,8 +14,6 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
"""Loads datasets, dashboards and slices in a new superset instance"""
import textwrap import textwrap
from superset import db from superset import db
@ -558,4 +556,3 @@ def load_tabbed_dashboard(_: bool = False) -> None:
dash.slices = slices dash.slices = slices
dash.dashboard_title = "Tabbed Dashboard" dash.dashboard_title = "Tabbed Dashboard"
dash.slug = slug dash.slug = slug
db.session.commit()

Some files were not shown because too many files have changed in this diff Show More