From 92e2ee9d0745f8717adea493636451751db0eb08 Mon Sep 17 00:00:00 2001 From: John Bodley <4567245+john-bodley@users.noreply.github.com> Date: Wed, 21 Jun 2023 09:30:07 -0700 Subject: [PATCH] chore(dao): Add generic type for better type checking (#24465) --- .../annotations/commands/delete.py | 2 ++ .../annotations/commands/update.py | 2 ++ superset/annotation_layers/commands/delete.py | 2 ++ superset/annotation_layers/commands/update.py | 2 ++ superset/charts/commands/update.py | 2 ++ superset/commands/export/models.py | 2 +- superset/daos/annotation.py | 8 ++--- superset/daos/base.py | 35 +++++++++---------- superset/daos/chart.py | 3 +- superset/daos/css.py | 4 +-- superset/daos/dashboard.py | 10 ++---- superset/daos/database.py | 7 ++-- superset/daos/dataset.py | 13 ++++--- superset/daos/datasource.py | 2 +- superset/daos/log.py | 4 +-- superset/daos/query.py | 6 ++-- superset/daos/report.py | 3 +- superset/daos/security.py | 4 +-- superset/daos/tag.py | 3 +- superset/dashboards/commands/delete.py | 2 ++ superset/dashboards/commands/update.py | 2 ++ .../dashboards/filter_sets/commands/delete.py | 4 ++- .../dashboards/filter_sets/commands/update.py | 2 ++ superset/databases/commands/delete.py | 2 ++ .../databases/ssh_tunnel/commands/delete.py | 2 ++ superset/reports/commands/delete.py | 2 ++ .../row_level_security/commands/update.py | 2 ++ 27 files changed, 68 insertions(+), 64 deletions(-) diff --git a/superset/annotation_layers/annotations/commands/delete.py b/superset/annotation_layers/annotations/commands/delete.py index b86ae997a..2af01f57f 100644 --- a/superset/annotation_layers/annotations/commands/delete.py +++ b/superset/annotation_layers/annotations/commands/delete.py @@ -38,6 +38,8 @@ class DeleteAnnotationCommand(BaseCommand): def run(self) -> Model: self.validate() + assert self._model + try: annotation = AnnotationDAO.delete(self._model) except DAODeleteFailedError as ex: diff --git a/superset/annotation_layers/annotations/commands/update.py b/superset/annotation_layers/annotations/commands/update.py index 03797a555..76287d24a 100644 --- a/superset/annotation_layers/annotations/commands/update.py +++ b/superset/annotation_layers/annotations/commands/update.py @@ -45,6 +45,8 @@ class UpdateAnnotationCommand(BaseCommand): def run(self) -> Model: self.validate() + assert self._model + try: annotation = AnnotationDAO.update(self._model, self._properties) except DAOUpdateFailedError as ex: diff --git a/superset/annotation_layers/commands/delete.py b/superset/annotation_layers/commands/delete.py index 0692d4dd8..1af4242dc 100644 --- a/superset/annotation_layers/commands/delete.py +++ b/superset/annotation_layers/commands/delete.py @@ -39,6 +39,8 @@ class DeleteAnnotationLayerCommand(BaseCommand): def run(self) -> Model: self.validate() + assert self._model + try: annotation_layer = AnnotationLayerDAO.delete(self._model) except DAODeleteFailedError as ex: diff --git a/superset/annotation_layers/commands/update.py b/superset/annotation_layers/commands/update.py index ca3a28841..e7f6963e8 100644 --- a/superset/annotation_layers/commands/update.py +++ b/superset/annotation_layers/commands/update.py @@ -42,6 +42,8 @@ class UpdateAnnotationLayerCommand(BaseCommand): def run(self) -> Model: self.validate() + assert self._model + try: annotation_layer = AnnotationLayerDAO.update(self._model, self._properties) except DAOUpdateFailedError as ex: diff --git a/superset/charts/commands/update.py b/superset/charts/commands/update.py index 9a5b4e1f2..32fd49e7c 100644 --- a/superset/charts/commands/update.py +++ b/superset/charts/commands/update.py @@ -56,6 +56,8 @@ class UpdateChartCommand(UpdateMixin, BaseCommand): def run(self) -> Model: self.validate() + assert self._model + try: if self._properties.get("query_context_generation") is None: self._properties["last_saved_at"] = datetime.now() diff --git a/superset/commands/export/models.py b/superset/commands/export/models.py index 27f4572af..61532d4a0 100644 --- a/superset/commands/export/models.py +++ b/superset/commands/export/models.py @@ -30,7 +30,7 @@ METADATA_FILE_NAME = "metadata.yaml" class ExportModelsCommand(BaseCommand): - dao: type[BaseDAO] = BaseDAO + dao: type[BaseDAO[Model]] = BaseDAO not_found: type[CommandException] = CommandException def __init__(self, model_ids: list[int], export_related: bool = True): diff --git a/superset/daos/annotation.py b/superset/daos/annotation.py index 171a708fa..2df336647 100644 --- a/superset/daos/annotation.py +++ b/superset/daos/annotation.py @@ -27,9 +27,7 @@ from superset.models.annotations import Annotation, AnnotationLayer logger = logging.getLogger(__name__) -class AnnotationDAO(BaseDAO): - model_cls = Annotation - +class AnnotationDAO(BaseDAO[Annotation]): @staticmethod def bulk_delete(models: Optional[list[Annotation]], commit: bool = True) -> None: item_ids = [model.id for model in models] if models else [] @@ -64,9 +62,7 @@ class AnnotationDAO(BaseDAO): return not db.session.query(query.exists()).scalar() -class AnnotationLayerDAO(BaseDAO): - model_cls = AnnotationLayer - +class AnnotationLayerDAO(BaseDAO[AnnotationLayer]): @staticmethod def bulk_delete( models: Optional[list[AnnotationLayer]], commit: bool = True diff --git a/superset/daos/base.py b/superset/daos/base.py index 6465e5b17..c0758f51d 100644 --- a/superset/daos/base.py +++ b/superset/daos/base.py @@ -14,8 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=isinstance-second-argument-not-valid-type -from typing import Any, Optional, Union +from typing import Any, Generic, get_args, Optional, TypeVar, Union from flask_appbuilder.models.filters import BaseFilter from flask_appbuilder.models.sqla import Model @@ -31,8 +30,10 @@ from superset.daos.exceptions import ( ) from superset.extensions import db +T = TypeVar("T", bound=Model) # pylint: disable=invalid-name -class BaseDAO: + +class BaseDAO(Generic[T]): """ Base DAO, implement base CRUD sqlalchemy operations """ @@ -48,6 +49,11 @@ class BaseDAO: """ id_column_name = "id" + def __init_subclass__(cls) -> None: # pylint: disable=arguments-differ + cls.model_cls = get_args( + cls.__orig_bases__[0] # type: ignore # pylint: disable=no-member + )[0] + @classmethod def find_by_id( cls, @@ -78,7 +84,7 @@ class BaseDAO: model_ids: Union[list[str], list[int]], session: Session = None, skip_base_filter: bool = False, - ) -> list[Model]: + ) -> list[T]: """ Find a List of models by a list of ids, if defined applies `base_filter` """ @@ -95,7 +101,7 @@ class BaseDAO: return query.all() @classmethod - def find_all(cls) -> list[Model]: + def find_all(cls) -> list[T]: """ Get all that fit the `base_filter` """ @@ -108,7 +114,7 @@ class BaseDAO: return query.all() @classmethod - def find_one_or_none(cls, **filter_by: Any) -> Optional[Model]: + def find_one_or_none(cls, **filter_by: Any) -> Optional[T]: """ Get the first that fit the `base_filter` """ @@ -121,7 +127,7 @@ class BaseDAO: return query.filter_by(**filter_by).one_or_none() @classmethod - def create(cls, properties: dict[str, Any], commit: bool = True) -> Model: + def create(cls, properties: dict[str, Any], commit: bool = True) -> T: """ Generic for creating models :raises: DAOCreateFailedError @@ -141,17 +147,13 @@ class BaseDAO: return model @classmethod - def save(cls, instance_model: Model, commit: bool = True) -> Model: + def save(cls, instance_model: T, commit: bool = True) -> None: """ Generic for saving models :raises: DAOCreateFailedError """ if cls.model_cls is None: raise DAOConfigError() - if not isinstance(instance_model, cls.model_cls): - raise DAOCreateFailedError( - "the instance model is not a type of the model class" - ) try: db.session.add(instance_model) if commit: @@ -159,12 +161,9 @@ class BaseDAO: except SQLAlchemyError as ex: # pragma: no cover db.session.rollback() raise DAOCreateFailedError(exception=ex) from ex - return instance_model @classmethod - def update( - cls, model: Model, properties: dict[str, Any], commit: bool = True - ) -> Model: + def update(cls, model: T, properties: dict[str, Any], commit: bool = True) -> T: """ Generic update a model :raises: DAOCreateFailedError @@ -181,7 +180,7 @@ class BaseDAO: return model @classmethod - def delete(cls, model: Model, commit: bool = True) -> Model: + def delete(cls, model: T, commit: bool = True) -> T: """ Generic delete a model :raises: DAODeleteFailedError @@ -196,7 +195,7 @@ class BaseDAO: return model @classmethod - def bulk_delete(cls, models: list[Model], commit: bool = True) -> None: + def bulk_delete(cls, models: list[T], commit: bool = True) -> None: try: for model in models: cls.delete(model, False) diff --git a/superset/daos/chart.py b/superset/daos/chart.py index 838d93abd..1a1396502 100644 --- a/superset/daos/chart.py +++ b/superset/daos/chart.py @@ -34,8 +34,7 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class ChartDAO(BaseDAO): - model_cls = Slice +class ChartDAO(BaseDAO[Slice]): base_filter = ChartFilter @staticmethod diff --git a/superset/daos/css.py b/superset/daos/css.py index 224277a40..3a1cbe8fd 100644 --- a/superset/daos/css.py +++ b/superset/daos/css.py @@ -27,9 +27,7 @@ from superset.models.core import CssTemplate logger = logging.getLogger(__name__) -class CssTemplateDAO(BaseDAO): - model_cls = CssTemplate - +class CssTemplateDAO(BaseDAO[CssTemplate]): @staticmethod def bulk_delete(models: Optional[list[CssTemplate]], commit: bool = True) -> None: item_ids = [model.id for model in models] if models else [] diff --git a/superset/daos/dashboard.py b/superset/daos/dashboard.py index 1e31591e1..1650711d5 100644 --- a/superset/daos/dashboard.py +++ b/superset/daos/dashboard.py @@ -49,8 +49,7 @@ from superset.utils.dashboard_filter_scopes_converter import copy_filter_scopes logger = logging.getLogger(__name__) -class DashboardDAO(BaseDAO): - model_cls = Dashboard +class DashboardDAO(BaseDAO[Dashboard]): base_filter = DashboardAccessFilter @classmethod @@ -379,8 +378,7 @@ class DashboardDAO(BaseDAO): db.session.commit() -class EmbeddedDashboardDAO(BaseDAO): - model_cls = EmbeddedDashboard +class EmbeddedDashboardDAO(BaseDAO[EmbeddedDashboard]): # There isn't really a regular scenario where we would rather get Embedded by id id_column_name = "uuid" @@ -407,9 +405,7 @@ class EmbeddedDashboardDAO(BaseDAO): raise NotImplementedError("Use EmbeddedDashboardDAO.upsert() instead.") -class FilterSetDAO(BaseDAO): - model_cls = FilterSet - +class FilterSetDAO(BaseDAO[FilterSet]): @classmethod def create(cls, properties: dict[str, Any], commit: bool = True) -> Model: if cls.model_cls is None: diff --git a/superset/daos/database.py b/superset/daos/database.py index 569568472..0a3cb65b2 100644 --- a/superset/daos/database.py +++ b/superset/daos/database.py @@ -31,8 +31,7 @@ from superset.utils.ssh_tunnel import unmask_password_info logger = logging.getLogger(__name__) -class DatabaseDAO(BaseDAO): - model_cls = Database +class DatabaseDAO(BaseDAO[Database]): base_filter = DatabaseFilter @classmethod @@ -138,9 +137,7 @@ class DatabaseDAO(BaseDAO): return ssh_tunnel -class SSHTunnelDAO(BaseDAO): - model_cls = SSHTunnel - +class SSHTunnelDAO(BaseDAO[SSHTunnel]): @classmethod def update( cls, diff --git a/superset/daos/dataset.py b/superset/daos/dataset.py index 3937e6c31..4634a7e46 100644 --- a/superset/daos/dataset.py +++ b/superset/daos/dataset.py @@ -31,8 +31,7 @@ from superset.views.base import DatasourceFilter logger = logging.getLogger(__name__) -class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods - model_cls = SqlaTable +class DatasetDAO(BaseDAO[SqlaTable]): # pylint: disable=too-many-public-methods base_filter = DatasourceFilter @staticmethod @@ -151,7 +150,7 @@ class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods model: SqlaTable, properties: dict[str, Any], commit: bool = True, - ) -> Optional[SqlaTable]: + ) -> SqlaTable: """ Updates a Dataset model on the metadata DB """ @@ -397,9 +396,9 @@ class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods ) -class DatasetColumnDAO(BaseDAO): - model_cls = TableColumn +class DatasetColumnDAO(BaseDAO[TableColumn]): + pass -class DatasetMetricDAO(BaseDAO): - model_cls = SqlMetric +class DatasetMetricDAO(BaseDAO[SqlMetric]): + pass diff --git a/superset/daos/datasource.py b/superset/daos/datasource.py index 684106161..2bdf4ca21 100644 --- a/superset/daos/datasource.py +++ b/superset/daos/datasource.py @@ -33,7 +33,7 @@ logger = logging.getLogger(__name__) Datasource = Union[Dataset, SqlaTable, Table, Query, SavedQuery] -class DatasourceDAO(BaseDAO): +class DatasourceDAO(BaseDAO[Datasource]): sources: dict[Union[DatasourceType, str], type[Datasource]] = { DatasourceType.TABLE: SqlaTable, DatasourceType.QUERY: Query, diff --git a/superset/daos/log.py b/superset/daos/log.py index 81767a48c..002c3f230 100644 --- a/superset/daos/log.py +++ b/superset/daos/log.py @@ -30,9 +30,7 @@ from superset.utils.core import get_user_id from superset.utils.dates import datetime_to_epoch -class LogDAO(BaseDAO): - model_cls = Log - +class LogDAO(BaseDAO[Log]): @staticmethod def get_recent_activity( actions: list[str], diff --git a/superset/daos/query.py b/superset/daos/query.py index 8996e27a3..80b5d1ad4 100644 --- a/superset/daos/query.py +++ b/superset/daos/query.py @@ -35,8 +35,7 @@ from superset.utils.dates import now_as_float logger = logging.getLogger(__name__) -class QueryDAO(BaseDAO): - model_cls = Query +class QueryDAO(BaseDAO[Query]): base_filter = QueryFilter @staticmethod @@ -104,8 +103,7 @@ class QueryDAO(BaseDAO): db.session.commit() -class SavedQueryDAO(BaseDAO): - model_cls = SavedQuery +class SavedQueryDAO(BaseDAO[SavedQuery]): base_filter = SavedQueryFilter @staticmethod diff --git a/superset/daos/report.py b/superset/daos/report.py index 4f8d914ad..70a87a645 100644 --- a/superset/daos/report.py +++ b/superset/daos/report.py @@ -42,8 +42,7 @@ logger = logging.getLogger(__name__) REPORT_SCHEDULE_ERROR_NOTIFICATION_MARKER = "Notification sent with error" -class ReportScheduleDAO(BaseDAO): - model_cls = ReportSchedule +class ReportScheduleDAO(BaseDAO[ReportSchedule]): base_filter = ReportScheduleFilter @staticmethod diff --git a/superset/daos/security.py b/superset/daos/security.py index a435f224a..392d741e3 100644 --- a/superset/daos/security.py +++ b/superset/daos/security.py @@ -19,5 +19,5 @@ from superset.connectors.sqla.models import RowLevelSecurityFilter from superset.daos.base import BaseDAO -class RLSDAO(BaseDAO): - model_cls = RowLevelSecurityFilter +class RLSDAO(BaseDAO[RowLevelSecurityFilter]): + pass diff --git a/superset/daos/tag.py b/superset/daos/tag.py index ec991edb1..90b0134ca 100644 --- a/superset/daos/tag.py +++ b/superset/daos/tag.py @@ -31,8 +31,7 @@ from superset.tags.models import get_tag, ObjectTypes, Tag, TaggedObject, TagTyp logger = logging.getLogger(__name__) -class TagDAO(BaseDAO): - model_cls = Tag +class TagDAO(BaseDAO[Tag]): # base_filter = TagAccessFilter @staticmethod diff --git a/superset/dashboards/commands/delete.py b/superset/dashboards/commands/delete.py index f774b92a5..1f5eb4ae3 100644 --- a/superset/dashboards/commands/delete.py +++ b/superset/dashboards/commands/delete.py @@ -44,6 +44,8 @@ class DeleteDashboardCommand(BaseCommand): def run(self) -> Model: self.validate() + assert self._model + try: dashboard = DashboardDAO.delete(self._model) except DAODeleteFailedError as ex: diff --git a/superset/dashboards/commands/update.py b/superset/dashboards/commands/update.py index cd9c07e0f..c880eebe8 100644 --- a/superset/dashboards/commands/update.py +++ b/superset/dashboards/commands/update.py @@ -48,6 +48,8 @@ class UpdateDashboardCommand(UpdateMixin, BaseCommand): def run(self) -> Model: self.validate() + assert self._model + try: dashboard = DashboardDAO.update(self._model, self._properties, commit=False) if self._properties.get("json_metadata"): diff --git a/superset/dashboards/filter_sets/commands/delete.py b/superset/dashboards/filter_sets/commands/delete.py index 93f438339..c05835424 100644 --- a/superset/dashboards/filter_sets/commands/delete.py +++ b/superset/dashboards/filter_sets/commands/delete.py @@ -36,8 +36,10 @@ class DeleteFilterSetCommand(BaseFilterSetCommand): self._filter_set_id = filter_set_id def run(self) -> Model: + self.validate() + assert self._filter_set + try: - self.validate() return FilterSetDAO.delete(self._filter_set, commit=True) except DAODeleteFailedError as err: raise FilterSetDeleteFailedError(str(self._filter_set_id), "") from err diff --git a/superset/dashboards/filter_sets/commands/update.py b/superset/dashboards/filter_sets/commands/update.py index eecaa34ae..a63c8d46f 100644 --- a/superset/dashboards/filter_sets/commands/update.py +++ b/superset/dashboards/filter_sets/commands/update.py @@ -39,6 +39,8 @@ class UpdateFilterSetCommand(BaseFilterSetCommand): def run(self) -> Model: try: self.validate() + assert self._filter_set + if ( OWNER_TYPE_FIELD in self._properties and self._properties[OWNER_TYPE_FIELD] == "Dashboard" diff --git a/superset/databases/commands/delete.py b/superset/databases/commands/delete.py index b8eb3f6e5..95d212e29 100644 --- a/superset/databases/commands/delete.py +++ b/superset/databases/commands/delete.py @@ -42,6 +42,8 @@ class DeleteDatabaseCommand(BaseCommand): def run(self) -> Model: self.validate() + assert self._model + try: database = DatabaseDAO.delete(self._model) except DAODeleteFailedError as ex: diff --git a/superset/databases/ssh_tunnel/commands/delete.py b/superset/databases/ssh_tunnel/commands/delete.py index 910df35a1..375c496f2 100644 --- a/superset/databases/ssh_tunnel/commands/delete.py +++ b/superset/databases/ssh_tunnel/commands/delete.py @@ -42,6 +42,8 @@ class DeleteSSHTunnelCommand(BaseCommand): if not is_feature_enabled("SSH_TUNNELING"): raise SSHTunnelingNotEnabledError() self.validate() + assert self._model + try: ssh_tunnel = SSHTunnelDAO.delete(self._model) except DAODeleteFailedError as ex: diff --git a/superset/reports/commands/delete.py b/superset/reports/commands/delete.py index 3f7e4e5d2..f52d96f7f 100644 --- a/superset/reports/commands/delete.py +++ b/superset/reports/commands/delete.py @@ -41,6 +41,8 @@ class DeleteReportScheduleCommand(BaseCommand): def run(self) -> Model: self.validate() + assert self._model + try: report_schedule = ReportScheduleDAO.delete(self._model) except DAODeleteFailedError as ex: diff --git a/superset/row_level_security/commands/update.py b/superset/row_level_security/commands/update.py index d44aa3efa..bc5ef368b 100644 --- a/superset/row_level_security/commands/update.py +++ b/superset/row_level_security/commands/update.py @@ -41,6 +41,8 @@ class UpdateRLSRuleCommand(BaseCommand): def run(self) -> Any: self.validate() + assert self._model + try: rule = RLSDAO.update(self._model, self._properties) except DAOUpdateFailedError as ex: