chore(dao): Add generic type for better type checking (#24465)

This commit is contained in:
John Bodley 2023-06-21 09:30:07 -07:00 committed by GitHub
parent d5f88c18f6
commit 92e2ee9d07
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
27 changed files with 68 additions and 64 deletions

View File

@ -38,6 +38,8 @@ class DeleteAnnotationCommand(BaseCommand):
def run(self) -> Model:
self.validate()
assert self._model
try:
annotation = AnnotationDAO.delete(self._model)
except DAODeleteFailedError as ex:

View File

@ -45,6 +45,8 @@ class UpdateAnnotationCommand(BaseCommand):
def run(self) -> Model:
self.validate()
assert self._model
try:
annotation = AnnotationDAO.update(self._model, self._properties)
except DAOUpdateFailedError as ex:

View File

@ -39,6 +39,8 @@ class DeleteAnnotationLayerCommand(BaseCommand):
def run(self) -> Model:
self.validate()
assert self._model
try:
annotation_layer = AnnotationLayerDAO.delete(self._model)
except DAODeleteFailedError as ex:

View File

@ -42,6 +42,8 @@ class UpdateAnnotationLayerCommand(BaseCommand):
def run(self) -> Model:
self.validate()
assert self._model
try:
annotation_layer = AnnotationLayerDAO.update(self._model, self._properties)
except DAOUpdateFailedError as ex:

View File

@ -56,6 +56,8 @@ class UpdateChartCommand(UpdateMixin, BaseCommand):
def run(self) -> Model:
self.validate()
assert self._model
try:
if self._properties.get("query_context_generation") is None:
self._properties["last_saved_at"] = datetime.now()

View File

@ -30,7 +30,7 @@ METADATA_FILE_NAME = "metadata.yaml"
class ExportModelsCommand(BaseCommand):
dao: type[BaseDAO] = BaseDAO
dao: type[BaseDAO[Model]] = BaseDAO
not_found: type[CommandException] = CommandException
def __init__(self, model_ids: list[int], export_related: bool = True):

View File

@ -27,9 +27,7 @@ from superset.models.annotations import Annotation, AnnotationLayer
logger = logging.getLogger(__name__)
class AnnotationDAO(BaseDAO):
model_cls = Annotation
class AnnotationDAO(BaseDAO[Annotation]):
@staticmethod
def bulk_delete(models: Optional[list[Annotation]], commit: bool = True) -> None:
item_ids = [model.id for model in models] if models else []
@ -64,9 +62,7 @@ class AnnotationDAO(BaseDAO):
return not db.session.query(query.exists()).scalar()
class AnnotationLayerDAO(BaseDAO):
model_cls = AnnotationLayer
class AnnotationLayerDAO(BaseDAO[AnnotationLayer]):
@staticmethod
def bulk_delete(
models: Optional[list[AnnotationLayer]], commit: bool = True

View File

@ -14,8 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=isinstance-second-argument-not-valid-type
from typing import Any, Optional, Union
from typing import Any, Generic, get_args, Optional, TypeVar, Union
from flask_appbuilder.models.filters import BaseFilter
from flask_appbuilder.models.sqla import Model
@ -31,8 +30,10 @@ from superset.daos.exceptions import (
)
from superset.extensions import db
T = TypeVar("T", bound=Model) # pylint: disable=invalid-name
class BaseDAO:
class BaseDAO(Generic[T]):
"""
Base DAO, implement base CRUD sqlalchemy operations
"""
@ -48,6 +49,11 @@ class BaseDAO:
"""
id_column_name = "id"
def __init_subclass__(cls) -> None: # pylint: disable=arguments-differ
cls.model_cls = get_args(
cls.__orig_bases__[0] # type: ignore # pylint: disable=no-member
)[0]
@classmethod
def find_by_id(
cls,
@ -78,7 +84,7 @@ class BaseDAO:
model_ids: Union[list[str], list[int]],
session: Session = None,
skip_base_filter: bool = False,
) -> list[Model]:
) -> list[T]:
"""
Find a List of models by a list of ids, if defined applies `base_filter`
"""
@ -95,7 +101,7 @@ class BaseDAO:
return query.all()
@classmethod
def find_all(cls) -> list[Model]:
def find_all(cls) -> list[T]:
"""
Get all that fit the `base_filter`
"""
@ -108,7 +114,7 @@ class BaseDAO:
return query.all()
@classmethod
def find_one_or_none(cls, **filter_by: Any) -> Optional[Model]:
def find_one_or_none(cls, **filter_by: Any) -> Optional[T]:
"""
Get the first that fit the `base_filter`
"""
@ -121,7 +127,7 @@ class BaseDAO:
return query.filter_by(**filter_by).one_or_none()
@classmethod
def create(cls, properties: dict[str, Any], commit: bool = True) -> Model:
def create(cls, properties: dict[str, Any], commit: bool = True) -> T:
"""
Generic for creating models
:raises: DAOCreateFailedError
@ -141,17 +147,13 @@ class BaseDAO:
return model
@classmethod
def save(cls, instance_model: Model, commit: bool = True) -> Model:
def save(cls, instance_model: T, commit: bool = True) -> None:
"""
Generic for saving models
:raises: DAOCreateFailedError
"""
if cls.model_cls is None:
raise DAOConfigError()
if not isinstance(instance_model, cls.model_cls):
raise DAOCreateFailedError(
"the instance model is not a type of the model class"
)
try:
db.session.add(instance_model)
if commit:
@ -159,12 +161,9 @@ class BaseDAO:
except SQLAlchemyError as ex: # pragma: no cover
db.session.rollback()
raise DAOCreateFailedError(exception=ex) from ex
return instance_model
@classmethod
def update(
cls, model: Model, properties: dict[str, Any], commit: bool = True
) -> Model:
def update(cls, model: T, properties: dict[str, Any], commit: bool = True) -> T:
"""
Generic update a model
:raises: DAOCreateFailedError
@ -181,7 +180,7 @@ class BaseDAO:
return model
@classmethod
def delete(cls, model: Model, commit: bool = True) -> Model:
def delete(cls, model: T, commit: bool = True) -> T:
"""
Generic delete a model
:raises: DAODeleteFailedError
@ -196,7 +195,7 @@ class BaseDAO:
return model
@classmethod
def bulk_delete(cls, models: list[Model], commit: bool = True) -> None:
def bulk_delete(cls, models: list[T], commit: bool = True) -> None:
try:
for model in models:
cls.delete(model, False)

View File

@ -34,8 +34,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
class ChartDAO(BaseDAO):
model_cls = Slice
class ChartDAO(BaseDAO[Slice]):
base_filter = ChartFilter
@staticmethod

View File

@ -27,9 +27,7 @@ from superset.models.core import CssTemplate
logger = logging.getLogger(__name__)
class CssTemplateDAO(BaseDAO):
model_cls = CssTemplate
class CssTemplateDAO(BaseDAO[CssTemplate]):
@staticmethod
def bulk_delete(models: Optional[list[CssTemplate]], commit: bool = True) -> None:
item_ids = [model.id for model in models] if models else []

View File

@ -49,8 +49,7 @@ from superset.utils.dashboard_filter_scopes_converter import copy_filter_scopes
logger = logging.getLogger(__name__)
class DashboardDAO(BaseDAO):
model_cls = Dashboard
class DashboardDAO(BaseDAO[Dashboard]):
base_filter = DashboardAccessFilter
@classmethod
@ -379,8 +378,7 @@ class DashboardDAO(BaseDAO):
db.session.commit()
class EmbeddedDashboardDAO(BaseDAO):
model_cls = EmbeddedDashboard
class EmbeddedDashboardDAO(BaseDAO[EmbeddedDashboard]):
# There isn't really a regular scenario where we would rather get Embedded by id
id_column_name = "uuid"
@ -407,9 +405,7 @@ class EmbeddedDashboardDAO(BaseDAO):
raise NotImplementedError("Use EmbeddedDashboardDAO.upsert() instead.")
class FilterSetDAO(BaseDAO):
model_cls = FilterSet
class FilterSetDAO(BaseDAO[FilterSet]):
@classmethod
def create(cls, properties: dict[str, Any], commit: bool = True) -> Model:
if cls.model_cls is None:

View File

@ -31,8 +31,7 @@ from superset.utils.ssh_tunnel import unmask_password_info
logger = logging.getLogger(__name__)
class DatabaseDAO(BaseDAO):
model_cls = Database
class DatabaseDAO(BaseDAO[Database]):
base_filter = DatabaseFilter
@classmethod
@ -138,9 +137,7 @@ class DatabaseDAO(BaseDAO):
return ssh_tunnel
class SSHTunnelDAO(BaseDAO):
model_cls = SSHTunnel
class SSHTunnelDAO(BaseDAO[SSHTunnel]):
@classmethod
def update(
cls,

View File

@ -31,8 +31,7 @@ from superset.views.base import DatasourceFilter
logger = logging.getLogger(__name__)
class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods
model_cls = SqlaTable
class DatasetDAO(BaseDAO[SqlaTable]): # pylint: disable=too-many-public-methods
base_filter = DatasourceFilter
@staticmethod
@ -151,7 +150,7 @@ class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods
model: SqlaTable,
properties: dict[str, Any],
commit: bool = True,
) -> Optional[SqlaTable]:
) -> SqlaTable:
"""
Updates a Dataset model on the metadata DB
"""
@ -397,9 +396,9 @@ class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods
)
class DatasetColumnDAO(BaseDAO):
model_cls = TableColumn
class DatasetColumnDAO(BaseDAO[TableColumn]):
pass
class DatasetMetricDAO(BaseDAO):
model_cls = SqlMetric
class DatasetMetricDAO(BaseDAO[SqlMetric]):
pass

View File

@ -33,7 +33,7 @@ logger = logging.getLogger(__name__)
Datasource = Union[Dataset, SqlaTable, Table, Query, SavedQuery]
class DatasourceDAO(BaseDAO):
class DatasourceDAO(BaseDAO[Datasource]):
sources: dict[Union[DatasourceType, str], type[Datasource]] = {
DatasourceType.TABLE: SqlaTable,
DatasourceType.QUERY: Query,

View File

@ -30,9 +30,7 @@ from superset.utils.core import get_user_id
from superset.utils.dates import datetime_to_epoch
class LogDAO(BaseDAO):
model_cls = Log
class LogDAO(BaseDAO[Log]):
@staticmethod
def get_recent_activity(
actions: list[str],

View File

@ -35,8 +35,7 @@ from superset.utils.dates import now_as_float
logger = logging.getLogger(__name__)
class QueryDAO(BaseDAO):
model_cls = Query
class QueryDAO(BaseDAO[Query]):
base_filter = QueryFilter
@staticmethod
@ -104,8 +103,7 @@ class QueryDAO(BaseDAO):
db.session.commit()
class SavedQueryDAO(BaseDAO):
model_cls = SavedQuery
class SavedQueryDAO(BaseDAO[SavedQuery]):
base_filter = SavedQueryFilter
@staticmethod

View File

@ -42,8 +42,7 @@ logger = logging.getLogger(__name__)
REPORT_SCHEDULE_ERROR_NOTIFICATION_MARKER = "Notification sent with error"
class ReportScheduleDAO(BaseDAO):
model_cls = ReportSchedule
class ReportScheduleDAO(BaseDAO[ReportSchedule]):
base_filter = ReportScheduleFilter
@staticmethod

View File

@ -19,5 +19,5 @@ from superset.connectors.sqla.models import RowLevelSecurityFilter
from superset.daos.base import BaseDAO
class RLSDAO(BaseDAO):
model_cls = RowLevelSecurityFilter
class RLSDAO(BaseDAO[RowLevelSecurityFilter]):
pass

View File

@ -31,8 +31,7 @@ from superset.tags.models import get_tag, ObjectTypes, Tag, TaggedObject, TagTyp
logger = logging.getLogger(__name__)
class TagDAO(BaseDAO):
model_cls = Tag
class TagDAO(BaseDAO[Tag]):
# base_filter = TagAccessFilter
@staticmethod

View File

@ -44,6 +44,8 @@ class DeleteDashboardCommand(BaseCommand):
def run(self) -> Model:
self.validate()
assert self._model
try:
dashboard = DashboardDAO.delete(self._model)
except DAODeleteFailedError as ex:

View File

@ -48,6 +48,8 @@ class UpdateDashboardCommand(UpdateMixin, BaseCommand):
def run(self) -> Model:
self.validate()
assert self._model
try:
dashboard = DashboardDAO.update(self._model, self._properties, commit=False)
if self._properties.get("json_metadata"):

View File

@ -36,8 +36,10 @@ class DeleteFilterSetCommand(BaseFilterSetCommand):
self._filter_set_id = filter_set_id
def run(self) -> Model:
self.validate()
assert self._filter_set
try:
self.validate()
return FilterSetDAO.delete(self._filter_set, commit=True)
except DAODeleteFailedError as err:
raise FilterSetDeleteFailedError(str(self._filter_set_id), "") from err

View File

@ -39,6 +39,8 @@ class UpdateFilterSetCommand(BaseFilterSetCommand):
def run(self) -> Model:
try:
self.validate()
assert self._filter_set
if (
OWNER_TYPE_FIELD in self._properties
and self._properties[OWNER_TYPE_FIELD] == "Dashboard"

View File

@ -42,6 +42,8 @@ class DeleteDatabaseCommand(BaseCommand):
def run(self) -> Model:
self.validate()
assert self._model
try:
database = DatabaseDAO.delete(self._model)
except DAODeleteFailedError as ex:

View File

@ -42,6 +42,8 @@ class DeleteSSHTunnelCommand(BaseCommand):
if not is_feature_enabled("SSH_TUNNELING"):
raise SSHTunnelingNotEnabledError()
self.validate()
assert self._model
try:
ssh_tunnel = SSHTunnelDAO.delete(self._model)
except DAODeleteFailedError as ex:

View File

@ -41,6 +41,8 @@ class DeleteReportScheduleCommand(BaseCommand):
def run(self) -> Model:
self.validate()
assert self._model
try:
report_schedule = ReportScheduleDAO.delete(self._model)
except DAODeleteFailedError as ex:

View File

@ -41,6 +41,8 @@ class UpdateRLSRuleCommand(BaseCommand):
def run(self) -> Any:
self.validate()
assert self._model
try:
rule = RLSDAO.update(self._model, self._properties)
except DAOUpdateFailedError as ex: