chore(dao): Replace save/overwrite with create/update respectively (#24467)

This commit is contained in:
John Bodley 2023-08-11 12:55:39 -07:00 committed by GitHub
parent a3d72e0ec7
commit ed0d288ccd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 196 additions and 194 deletions

View File

@ -42,11 +42,10 @@ class CreateAnnotationCommand(BaseCommand):
def run(self) -> Model:
self.validate()
try:
annotation = AnnotationDAO.create(self._properties)
return AnnotationDAO.create(attributes=self._properties)
except DAOCreateFailedError as ex:
logger.exception(ex.exception)
raise AnnotationCreateFailedError() from ex
return annotation
def validate(self) -> None:
exceptions: list[ValidationError] = []

View File

@ -39,11 +39,10 @@ class CreateAnnotationLayerCommand(BaseCommand):
def run(self) -> Model:
self.validate()
try:
annotation_layer = AnnotationLayerDAO.create(self._properties)
return AnnotationLayerDAO.create(attributes=self._properties)
except DAOCreateFailedError as ex:
logger.exception(ex.exception)
raise AnnotationLayerCreateFailedError() from ex
return annotation_layer
def validate(self) -> None:
exceptions: list[ValidationError] = []

View File

@ -47,11 +47,10 @@ class CreateChartCommand(CreateMixin, BaseCommand):
try:
self._properties["last_saved_at"] = datetime.now()
self._properties["last_saved_by"] = g.user
chart = ChartDAO.create(self._properties)
return ChartDAO.create(attributes=self._properties)
except DAOCreateFailedError as ex:
logger.exception(ex.exception)
raise ChartCreateFailedError() from ex
return chart
def validate(self) -> None:
exceptions = []

View File

@ -25,7 +25,6 @@ from sqlalchemy.exc import SQLAlchemyError, StatementError
from sqlalchemy.orm import Session
from superset.daos.exceptions import (
DAOConfigError,
DAOCreateFailedError,
DAODeleteFailedError,
DAOUpdateFailedError,
@ -130,57 +129,72 @@ class BaseDAO(Generic[T]):
return query.filter_by(**filter_by).one_or_none()
@classmethod
def create(cls, properties: dict[str, Any], commit: bool = True) -> T:
def create(
cls,
item: T | None = None,
attributes: dict[str, Any] | None = None,
commit: bool = True,
) -> T:
"""
Generic for creating models
:raises: DAOCreateFailedError
"""
if cls.model_cls is None:
raise DAOConfigError()
model = cls.model_cls() # pylint: disable=not-callable
for key, value in properties.items():
setattr(model, key, value)
try:
db.session.add(model)
if commit:
db.session.commit()
except SQLAlchemyError as ex: # pragma: no cover
db.session.rollback()
raise DAOCreateFailedError(exception=ex) from ex
return model
Create an object from the specified item and/or attributes.
@classmethod
def save(cls, instance_model: T, commit: bool = True) -> None:
:param item: The object to create
:param attributes: The attributes associated with the object to create
:param commit: Whether to commit the transaction
:raises DAOCreateFailedError: If the creation failed
"""
Generic for saving models
:raises: DAOCreateFailedError
"""
if cls.model_cls is None:
raise DAOConfigError()
if not item:
item = cls.model_cls() # type: ignore # pylint: disable=not-callable
if attributes:
for key, value in attributes.items():
setattr(item, key, value)
try:
db.session.add(instance_model)
db.session.add(item)
if commit:
db.session.commit()
except SQLAlchemyError as ex: # pragma: no cover
db.session.rollback()
raise DAOCreateFailedError(exception=ex) from ex
return item # type: ignore
@classmethod
def update(cls, model: T, properties: dict[str, Any], commit: bool = True) -> T:
def update(
cls,
item: T | None = None,
attributes: dict[str, Any] | None = None,
commit: bool = True,
) -> T:
"""
Generic update a model
:raises: DAOCreateFailedError
Update an object from the specified item and/or attributes.
:param item: The object to update
:param attributes: The attributes associated with the object to update
:param commit: Whether to commit the transaction
:raises DAOUpdateFailedError: If the updating failed
"""
for key, value in properties.items():
setattr(model, key, value)
if not item:
item = cls.model_cls() # type: ignore # pylint: disable=not-callable
if attributes:
for key, value in attributes.items():
setattr(item, key, value)
try:
db.session.merge(model)
db.session.merge(item)
if commit:
db.session.commit()
except SQLAlchemyError as ex: # pragma: no cover
db.session.rollback()
raise DAOUpdateFailedError(exception=ex) from ex
return model
return item # type: ignore
@classmethod
def delete(cls, items: T | list[T], commit: bool = True) -> None:

View File

@ -14,7 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=arguments-renamed
from __future__ import annotations
import logging
@ -54,18 +53,6 @@ class ChartDAO(BaseDAO[Slice]):
db.session.rollback()
raise ex
@staticmethod
def save(slc: Slice, commit: bool = True) -> None:
db.session.add(slc)
if commit:
db.session.commit()
@staticmethod
def overwrite(slc: Slice, commit: bool = True) -> None:
db.session.merge(slc)
if commit:
db.session.commit()
@staticmethod
def favorited_ids(charts: list[Slice]) -> list[FavStar]:
ids = [chart.id for chart in charts]

View File

@ -22,13 +22,11 @@ from datetime import datetime
from typing import Any
from flask import g
from flask_appbuilder.models.sqla import Model
from flask_appbuilder.models.sqla.interface import SQLAInterface
from sqlalchemy.exc import SQLAlchemyError
from superset import is_feature_enabled, security_manager
from superset.daos.base import BaseDAO
from superset.daos.exceptions import DAOConfigError, DAOCreateFailedError
from superset.dashboards.commands.exceptions import (
DashboardAccessDeniedError,
DashboardForbiddenError,
@ -403,35 +401,40 @@ class EmbeddedDashboardDAO(BaseDAO[EmbeddedDashboard]):
return embedded
@classmethod
def create(cls, properties: dict[str, Any], commit: bool = True) -> Any:
def create(
cls,
item: EmbeddedDashboardDAO | None = None,
attributes: dict[str, Any] | None = None,
commit: bool = True,
) -> Any:
"""
Use EmbeddedDashboardDAO.upsert() instead.
At least, until we are ok with more than one embedded instance per dashboard.
At least, until we are ok with more than one embedded item per dashboard.
"""
raise NotImplementedError("Use EmbeddedDashboardDAO.upsert() instead.")
class FilterSetDAO(BaseDAO[FilterSet]):
@classmethod
def create(cls, properties: dict[str, Any], commit: bool = True) -> Model:
if cls.model_cls is None:
raise DAOConfigError()
model = FilterSet()
setattr(model, NAME_FIELD, properties[NAME_FIELD])
setattr(model, JSON_METADATA_FIELD, properties[JSON_METADATA_FIELD])
setattr(model, DESCRIPTION_FIELD, properties.get(DESCRIPTION_FIELD, None))
setattr(
model,
OWNER_ID_FIELD,
properties.get(OWNER_ID_FIELD, properties[DASHBOARD_ID_FIELD]),
)
setattr(model, OWNER_TYPE_FIELD, properties[OWNER_TYPE_FIELD])
setattr(model, DASHBOARD_ID_FIELD, properties[DASHBOARD_ID_FIELD])
try:
db.session.add(model)
if commit:
db.session.commit()
except SQLAlchemyError as ex: # pragma: no cover
db.session.rollback()
raise DAOCreateFailedError() from ex
return model
def create(
cls,
item: FilterSet | None = None,
attributes: dict[str, Any] | None = None,
commit: bool = True,
) -> FilterSet:
if not item:
item = FilterSet()
if attributes:
setattr(item, NAME_FIELD, attributes[NAME_FIELD])
setattr(item, JSON_METADATA_FIELD, attributes[JSON_METADATA_FIELD])
setattr(item, DESCRIPTION_FIELD, attributes.get(DESCRIPTION_FIELD, None))
setattr(
item,
OWNER_ID_FIELD,
attributes.get(OWNER_ID_FIELD, attributes[DASHBOARD_ID_FIELD]),
)
setattr(item, OWNER_TYPE_FIELD, attributes[OWNER_TYPE_FIELD])
setattr(item, DASHBOARD_ID_FIELD, attributes[DASHBOARD_ID_FIELD])
return super().create(item, commit=commit)

View File

@ -14,8 +14,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import logging
from typing import Any, Optional
from typing import Any
from superset.daos.base import BaseDAO
from superset.databases.filters import DatabaseFilter
@ -37,8 +39,8 @@ class DatabaseDAO(BaseDAO[Database]):
@classmethod
def update(
cls,
model: Database,
properties: dict[str, Any],
item: Database | None = None,
attributes: dict[str, Any] | None = None,
commit: bool = True,
) -> Database:
"""
@ -50,13 +52,14 @@ class DatabaseDAO(BaseDAO[Database]):
The masked values should be unmasked before the database is updated.
"""
if "encrypted_extra" in properties:
properties["encrypted_extra"] = model.db_engine_spec.unmask_encrypted_extra(
model.encrypted_extra,
properties["encrypted_extra"],
if item and attributes and "encrypted_extra" in attributes:
attributes["encrypted_extra"] = item.db_engine_spec.unmask_encrypted_extra(
item.encrypted_extra,
attributes["encrypted_extra"],
)
return super().update(model, properties, commit)
return super().update(item, attributes, commit)
@staticmethod
def validate_uniqueness(database_name: str) -> bool:
@ -74,7 +77,7 @@ class DatabaseDAO(BaseDAO[Database]):
return not db.session.query(database_query.exists()).scalar()
@staticmethod
def get_database_by_name(database_name: str) -> Optional[Database]:
def get_database_by_name(database_name: str) -> Database | None:
return (
db.session.query(Database)
.filter(Database.database_name == database_name)
@ -129,7 +132,7 @@ class DatabaseDAO(BaseDAO[Database]):
}
@classmethod
def get_ssh_tunnel(cls, database_id: int) -> Optional[SSHTunnel]:
def get_ssh_tunnel(cls, database_id: int) -> SSHTunnel | None:
ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == database_id)
@ -143,8 +146,8 @@ class SSHTunnelDAO(BaseDAO[SSHTunnel]):
@classmethod
def update(
cls,
model: SSHTunnel,
properties: dict[str, Any],
item: SSHTunnel | None = None,
attributes: dict[str, Any] | None = None,
commit: bool = True,
) -> SSHTunnel:
"""
@ -156,7 +159,9 @@ class SSHTunnelDAO(BaseDAO[SSHTunnel]):
The masked values should be unmasked before the ssh tunnel is updated.
"""
# ID cannot be updated so we remove it if present in the payload
properties.pop("id", None)
properties = unmask_password_info(properties, model)
return super().update(model, properties, commit)
if item and attributes:
attributes.pop("id", None)
attributes = unmask_password_info(attributes, item)
return super().update(item, attributes, commit)

View File

@ -150,26 +150,27 @@ class DatasetDAO(BaseDAO[SqlaTable]): # pylint: disable=too-many-public-methods
@classmethod
def update(
cls,
model: SqlaTable,
properties: dict[str, Any],
item: SqlaTable | None = None,
attributes: dict[str, Any] | None = None,
commit: bool = True,
) -> SqlaTable:
"""
Updates a Dataset model on the metadata DB
"""
if "columns" in properties:
cls.update_columns(
model,
properties.pop("columns"),
commit=commit,
override_columns=bool(properties.get("override_columns")),
)
if item and attributes:
if "columns" in attributes:
cls.update_columns(
item,
attributes.pop("columns"),
commit=commit,
override_columns=bool(attributes.get("override_columns")),
)
if "metrics" in properties:
cls.update_metrics(model, properties.pop("metrics"), commit=commit)
if "metrics" in attributes:
cls.update_metrics(item, attributes.pop("metrics"), commit=commit)
return super().update(model, properties, commit=commit)
return super().update(item, attributes, commit=commit)
@classmethod
def update_columns(
@ -316,7 +317,7 @@ class DatasetDAO(BaseDAO[SqlaTable]): # pylint: disable=too-many-public-methods
"""
Creates a Dataset model on the metadata DB
"""
return DatasetColumnDAO.create(properties, commit=commit)
return DatasetColumnDAO.create(attributes=properties, commit=commit)
@classmethod
def delete_column(cls, model: TableColumn, commit: bool = True) -> None:
@ -358,7 +359,7 @@ class DatasetDAO(BaseDAO[SqlaTable]): # pylint: disable=too-many-public-methods
"""
Creates a Dataset model on the metadata DB
"""
return DatasetMetricDAO.create(properties, commit=commit)
return DatasetMetricDAO.create(attributes=properties, commit=commit)
@classmethod
def delete(

View File

@ -36,7 +36,7 @@ class DAOUpdateFailedError(DAOException):
DAO Update failed
"""
message = "Updated failed"
message = "Update failed"
class DAODeleteFailedError(DAOException):
@ -47,14 +47,6 @@ class DAODeleteFailedError(DAOException):
message = "Delete failed"
class DAOConfigError(DAOException):
"""
DAO is miss configured
"""
message = "DAO is not configured correctly missing model definition"
class DatasourceTypeNotSupportedError(DAOException):
"""
DAO datasource query source type is not supported

View File

@ -21,12 +21,11 @@ import logging
from datetime import datetime
from typing import Any
from flask_appbuilder import Model
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session
from superset.daos.base import BaseDAO
from superset.daos.exceptions import DAOCreateFailedError, DAODeleteFailedError
from superset.daos.exceptions import DAODeleteFailedError
from superset.extensions import db
from superset.reports.filters import ReportScheduleFilter
from superset.reports.models import (
@ -135,67 +134,74 @@ class ReportScheduleDAO(BaseDAO[ReportSchedule]):
return found_id is None or found_id == expect_id
@classmethod
def create(cls, properties: dict[str, Any], commit: bool = True) -> ReportSchedule:
"""
create a report schedule and nested recipients
:raises: DAOCreateFailedError
"""
try:
model = ReportSchedule()
for key, value in properties.items():
if key != "recipients":
setattr(model, key, value)
recipients = properties.get("recipients", [])
for recipient in recipients:
model.recipients.append( # pylint: disable=no-member
ReportRecipients(
type=recipient["type"],
recipient_config_json=json.dumps(
recipient["recipient_config_json"]
),
)
)
db.session.add(model)
if commit:
db.session.commit()
return model
except SQLAlchemyError as ex:
db.session.rollback()
raise DAOCreateFailedError(str(ex)) from ex
@classmethod
def update(
cls, model: Model, properties: dict[str, Any], commit: bool = True
def create(
cls,
item: ReportSchedule | None = None,
attributes: dict[str, Any] | None = None,
commit: bool = True,
) -> ReportSchedule:
"""
create a report schedule and nested recipients
:raises: DAOCreateFailedError
Create a report schedule with nested recipients.
:param item: The object to create
:param attributes: The attributes associated with the object to create
:param commit: Whether to commit the transaction
:raises: DAOCreateFailedError: If the creation failed
"""
try:
for key, value in properties.items():
if key != "recipients":
setattr(model, key, value)
if "recipients" in properties:
recipients = properties["recipients"]
model.recipients = [
# TODO(john-bodley): Determine why we need special handling for recipients.
if not item:
item = ReportSchedule()
if attributes:
if recipients := attributes.pop("recipients", None):
attributes["recipients"] = [
ReportRecipients(
type=recipient["type"],
recipient_config_json=json.dumps(
recipient["recipient_config_json"]
),
report_schedule=model,
report_schedule=item,
)
for recipient in recipients
]
db.session.merge(model)
if commit:
db.session.commit()
return model
except SQLAlchemyError as ex:
db.session.rollback()
raise DAOCreateFailedError(str(ex)) from ex
return super().create(item, attributes, commit)
@classmethod
def update(
cls,
item: ReportSchedule | None = None,
attributes: dict[str, Any] | None = None,
commit: bool = True,
) -> ReportSchedule:
"""
Update a report schedule with nested recipients.
:param item: The object to update
:param attributes: The attributes associated with the object to update
:param commit: Whether to commit the transaction
:raises: DAOUpdateFailedError: If the updation failed
"""
# TODO(john-bodley): Determine why we need special handling for recipients.
if not item:
item = ReportSchedule()
if attributes:
if recipients := attributes.pop("recipients", None):
attributes["recipients"] = [
ReportRecipients(
type=recipient["type"],
recipient_config_json=json.dumps(
recipient["recipient_config_json"]
),
report_schedule=item,
)
for recipient in recipients
]
return super().update(item, attributes, commit)
@staticmethod
def find_active(session: Session | None = None) -> list[ReportSchedule]:

View File

@ -40,7 +40,7 @@ class CreateDashboardCommand(CreateMixin, BaseCommand):
def run(self) -> Model:
self.validate()
try:
dashboard = DashboardDAO.create(self._properties, commit=False)
dashboard = DashboardDAO.create(attributes=self._properties, commit=False)
dashboard = DashboardDAO.update_charts_owners(dashboard, commit=True)
except DAOCreateFailedError as ex:
logger.exception(ex.exception)

View File

@ -47,8 +47,7 @@ class CreateFilterSetCommand(BaseFilterSetCommand):
def run(self) -> Model:
self.validate()
self._properties[DASHBOARD_ID_FIELD] = self._dashboard.id
filter_set = FilterSetDAO.create(self._properties, commit=True)
return filter_set
return FilterSetDAO.create(attributes=self._properties, commit=True)
def validate(self) -> None:
self._validate_filterset_dashboard_exists()

View File

@ -77,7 +77,7 @@ class CreateDatabaseCommand(BaseCommand):
)
try:
database = DatabaseDAO.create(self._properties, commit=False)
database = DatabaseDAO.create(attributes=self._properties, commit=False)
database.set_sqlalchemy_uri(database.sqlalchemy_uri)
ssh_tunnel = None

View File

@ -46,7 +46,7 @@ class CreateSSHTunnelCommand(BaseCommand):
# test_do_not_create_database_if_ssh_tunnel_creation_fails test will fail
db.session.begin_nested()
self.validate()
tunnel = SSHTunnelDAO.create(self._properties, commit=False)
return SSHTunnelDAO.create(attributes=self._properties, commit=False)
except DAOCreateFailedError as ex:
# Rollback nested transaction
db.session.rollback()
@ -56,8 +56,6 @@ class CreateSSHTunnelCommand(BaseCommand):
db.session.rollback()
raise ex
return tunnel
def validate(self) -> None:
# TODO(hughhh): check to make sure the server port is not localhost
# using the config.SSH_TUNNEL_MANAGER

View File

@ -44,7 +44,7 @@ class CreateDatasetCommand(CreateMixin, BaseCommand):
self.validate()
try:
# Creates SqlaTable (Dataset)
dataset = DatasetDAO.create(self._properties, commit=False)
dataset = DatasetDAO.create(attributes=self._properties, commit=False)
# Updates columns and metrics from the datase
dataset.fetch_metadata(commit=False)

View File

@ -66,8 +66,8 @@ class UpdateDatasetCommand(UpdateMixin, BaseCommand):
if self._model:
try:
dataset = DatasetDAO.update(
model=self._model,
properties=self._properties,
self._model,
attributes=self._properties,
)
return dataset
except DAOUpdateFailedError as ex:

View File

@ -52,11 +52,10 @@ class CreateReportScheduleCommand(CreateMixin, BaseReportScheduleCommand):
def run(self) -> ReportSchedule:
self.validate()
try:
report_schedule = ReportScheduleDAO.create(self._properties)
return ReportScheduleDAO.create(attributes=self._properties)
except DAOCreateFailedError as ex:
logger.exception(ex.exception)
raise ReportScheduleCreateFailedError() from ex
return report_schedule
def validate(self) -> None:
exceptions: list[ValidationError] = []

View File

@ -49,6 +49,8 @@ class UpdateReportScheduleCommand(UpdateMixin, BaseReportScheduleCommand):
def run(self) -> Model:
self.validate()
assert self._model
try:
report_schedule = ReportScheduleDAO.update(self._model, self._properties)
except DAOUpdateFailedError as ex:

View File

@ -39,13 +39,11 @@ class CreateRLSRuleCommand(BaseCommand):
def run(self) -> Any:
self.validate()
try:
rule = RLSDAO.create(self._properties)
return RLSDAO.create(attributes=self._properties)
except DAOCreateFailedError as ex:
logger.exception(ex.exception)
raise ex
return rule
def validate(self) -> None:
roles = populate_roles(self._roles)
tables = (

View File

@ -177,7 +177,7 @@ class ExecuteSqlCommand(BaseCommand):
def _save_new_query(self, query: Query) -> None:
try:
self._query_dao.save(query)
self._query_dao.create(query)
except DAOCreateFailedError as ex:
raise SqlLabException(
self._execution_context,

View File

@ -695,11 +695,11 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
slc.query_context = query_context
if action == "saveas" and slice_add_perm:
ChartDAO.save(slc)
ChartDAO.create(slc)
msg = _("Chart [{}] has been saved").format(slc.slice_name)
flash(msg, "success")
elif action == "overwrite" and slice_overwrite_perm:
ChartDAO.overwrite(slc)
ChartDAO.update(slc)
msg = _("Chart [{}] has been overwritten").format(slc.slice_name)
flash(msg, "success")

View File

@ -27,15 +27,16 @@ def test_create_ssh_tunnel():
db = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
properties = {
"database_id": db.id,
"server_address": "123.132.123.1",
"server_port": "3005",
"username": "foo",
"password": "bar",
}
result = SSHTunnelDAO.create(properties, commit=False)
result = SSHTunnelDAO.create(
attributes={
"database_id": db.id,
"server_address": "123.132.123.1",
"server_port": "3005",
"username": "foo",
"password": "bar",
},
commit=False,
)
assert result is not None
assert isinstance(result, SSHTunnel)