chore(dao): Add generic type for better type checking (#24465)
This commit is contained in:
parent
d5f88c18f6
commit
92e2ee9d07
|
|
@ -38,6 +38,8 @@ class DeleteAnnotationCommand(BaseCommand):
|
|||
|
||||
def run(self) -> Model:
|
||||
self.validate()
|
||||
assert self._model
|
||||
|
||||
try:
|
||||
annotation = AnnotationDAO.delete(self._model)
|
||||
except DAODeleteFailedError as ex:
|
||||
|
|
|
|||
|
|
@ -45,6 +45,8 @@ class UpdateAnnotationCommand(BaseCommand):
|
|||
|
||||
def run(self) -> Model:
|
||||
self.validate()
|
||||
assert self._model
|
||||
|
||||
try:
|
||||
annotation = AnnotationDAO.update(self._model, self._properties)
|
||||
except DAOUpdateFailedError as ex:
|
||||
|
|
|
|||
|
|
@ -39,6 +39,8 @@ class DeleteAnnotationLayerCommand(BaseCommand):
|
|||
|
||||
def run(self) -> Model:
|
||||
self.validate()
|
||||
assert self._model
|
||||
|
||||
try:
|
||||
annotation_layer = AnnotationLayerDAO.delete(self._model)
|
||||
except DAODeleteFailedError as ex:
|
||||
|
|
|
|||
|
|
@ -42,6 +42,8 @@ class UpdateAnnotationLayerCommand(BaseCommand):
|
|||
|
||||
def run(self) -> Model:
|
||||
self.validate()
|
||||
assert self._model
|
||||
|
||||
try:
|
||||
annotation_layer = AnnotationLayerDAO.update(self._model, self._properties)
|
||||
except DAOUpdateFailedError as ex:
|
||||
|
|
|
|||
|
|
@ -56,6 +56,8 @@ class UpdateChartCommand(UpdateMixin, BaseCommand):
|
|||
|
||||
def run(self) -> Model:
|
||||
self.validate()
|
||||
assert self._model
|
||||
|
||||
try:
|
||||
if self._properties.get("query_context_generation") is None:
|
||||
self._properties["last_saved_at"] = datetime.now()
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ METADATA_FILE_NAME = "metadata.yaml"
|
|||
|
||||
|
||||
class ExportModelsCommand(BaseCommand):
|
||||
dao: type[BaseDAO] = BaseDAO
|
||||
dao: type[BaseDAO[Model]] = BaseDAO
|
||||
not_found: type[CommandException] = CommandException
|
||||
|
||||
def __init__(self, model_ids: list[int], export_related: bool = True):
|
||||
|
|
|
|||
|
|
@ -27,9 +27,7 @@ from superset.models.annotations import Annotation, AnnotationLayer
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AnnotationDAO(BaseDAO):
|
||||
model_cls = Annotation
|
||||
|
||||
class AnnotationDAO(BaseDAO[Annotation]):
|
||||
@staticmethod
|
||||
def bulk_delete(models: Optional[list[Annotation]], commit: bool = True) -> None:
|
||||
item_ids = [model.id for model in models] if models else []
|
||||
|
|
@ -64,9 +62,7 @@ class AnnotationDAO(BaseDAO):
|
|||
return not db.session.query(query.exists()).scalar()
|
||||
|
||||
|
||||
class AnnotationLayerDAO(BaseDAO):
|
||||
model_cls = AnnotationLayer
|
||||
|
||||
class AnnotationLayerDAO(BaseDAO[AnnotationLayer]):
|
||||
@staticmethod
|
||||
def bulk_delete(
|
||||
models: Optional[list[AnnotationLayer]], commit: bool = True
|
||||
|
|
|
|||
|
|
@ -14,8 +14,7 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
# pylint: disable=isinstance-second-argument-not-valid-type
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Generic, get_args, Optional, TypeVar, Union
|
||||
|
||||
from flask_appbuilder.models.filters import BaseFilter
|
||||
from flask_appbuilder.models.sqla import Model
|
||||
|
|
@ -31,8 +30,10 @@ from superset.daos.exceptions import (
|
|||
)
|
||||
from superset.extensions import db
|
||||
|
||||
T = TypeVar("T", bound=Model) # pylint: disable=invalid-name
|
||||
|
||||
class BaseDAO:
|
||||
|
||||
class BaseDAO(Generic[T]):
|
||||
"""
|
||||
Base DAO, implement base CRUD sqlalchemy operations
|
||||
"""
|
||||
|
|
@ -48,6 +49,11 @@ class BaseDAO:
|
|||
"""
|
||||
id_column_name = "id"
|
||||
|
||||
def __init_subclass__(cls) -> None: # pylint: disable=arguments-differ
|
||||
cls.model_cls = get_args(
|
||||
cls.__orig_bases__[0] # type: ignore # pylint: disable=no-member
|
||||
)[0]
|
||||
|
||||
@classmethod
|
||||
def find_by_id(
|
||||
cls,
|
||||
|
|
@ -78,7 +84,7 @@ class BaseDAO:
|
|||
model_ids: Union[list[str], list[int]],
|
||||
session: Session = None,
|
||||
skip_base_filter: bool = False,
|
||||
) -> list[Model]:
|
||||
) -> list[T]:
|
||||
"""
|
||||
Find a List of models by a list of ids, if defined applies `base_filter`
|
||||
"""
|
||||
|
|
@ -95,7 +101,7 @@ class BaseDAO:
|
|||
return query.all()
|
||||
|
||||
@classmethod
|
||||
def find_all(cls) -> list[Model]:
|
||||
def find_all(cls) -> list[T]:
|
||||
"""
|
||||
Get all that fit the `base_filter`
|
||||
"""
|
||||
|
|
@ -108,7 +114,7 @@ class BaseDAO:
|
|||
return query.all()
|
||||
|
||||
@classmethod
|
||||
def find_one_or_none(cls, **filter_by: Any) -> Optional[Model]:
|
||||
def find_one_or_none(cls, **filter_by: Any) -> Optional[T]:
|
||||
"""
|
||||
Get the first that fit the `base_filter`
|
||||
"""
|
||||
|
|
@ -121,7 +127,7 @@ class BaseDAO:
|
|||
return query.filter_by(**filter_by).one_or_none()
|
||||
|
||||
@classmethod
|
||||
def create(cls, properties: dict[str, Any], commit: bool = True) -> Model:
|
||||
def create(cls, properties: dict[str, Any], commit: bool = True) -> T:
|
||||
"""
|
||||
Generic for creating models
|
||||
:raises: DAOCreateFailedError
|
||||
|
|
@ -141,17 +147,13 @@ class BaseDAO:
|
|||
return model
|
||||
|
||||
@classmethod
|
||||
def save(cls, instance_model: Model, commit: bool = True) -> Model:
|
||||
def save(cls, instance_model: T, commit: bool = True) -> None:
|
||||
"""
|
||||
Generic for saving models
|
||||
:raises: DAOCreateFailedError
|
||||
"""
|
||||
if cls.model_cls is None:
|
||||
raise DAOConfigError()
|
||||
if not isinstance(instance_model, cls.model_cls):
|
||||
raise DAOCreateFailedError(
|
||||
"the instance model is not a type of the model class"
|
||||
)
|
||||
try:
|
||||
db.session.add(instance_model)
|
||||
if commit:
|
||||
|
|
@ -159,12 +161,9 @@ class BaseDAO:
|
|||
except SQLAlchemyError as ex: # pragma: no cover
|
||||
db.session.rollback()
|
||||
raise DAOCreateFailedError(exception=ex) from ex
|
||||
return instance_model
|
||||
|
||||
@classmethod
|
||||
def update(
|
||||
cls, model: Model, properties: dict[str, Any], commit: bool = True
|
||||
) -> Model:
|
||||
def update(cls, model: T, properties: dict[str, Any], commit: bool = True) -> T:
|
||||
"""
|
||||
Generic update a model
|
||||
:raises: DAOCreateFailedError
|
||||
|
|
@ -181,7 +180,7 @@ class BaseDAO:
|
|||
return model
|
||||
|
||||
@classmethod
|
||||
def delete(cls, model: Model, commit: bool = True) -> Model:
|
||||
def delete(cls, model: T, commit: bool = True) -> T:
|
||||
"""
|
||||
Generic delete a model
|
||||
:raises: DAODeleteFailedError
|
||||
|
|
@ -196,7 +195,7 @@ class BaseDAO:
|
|||
return model
|
||||
|
||||
@classmethod
|
||||
def bulk_delete(cls, models: list[Model], commit: bool = True) -> None:
|
||||
def bulk_delete(cls, models: list[T], commit: bool = True) -> None:
|
||||
try:
|
||||
for model in models:
|
||||
cls.delete(model, False)
|
||||
|
|
|
|||
|
|
@ -34,8 +34,7 @@ if TYPE_CHECKING:
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChartDAO(BaseDAO):
|
||||
model_cls = Slice
|
||||
class ChartDAO(BaseDAO[Slice]):
|
||||
base_filter = ChartFilter
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -27,9 +27,7 @@ from superset.models.core import CssTemplate
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CssTemplateDAO(BaseDAO):
|
||||
model_cls = CssTemplate
|
||||
|
||||
class CssTemplateDAO(BaseDAO[CssTemplate]):
|
||||
@staticmethod
|
||||
def bulk_delete(models: Optional[list[CssTemplate]], commit: bool = True) -> None:
|
||||
item_ids = [model.id for model in models] if models else []
|
||||
|
|
|
|||
|
|
@ -49,8 +49,7 @@ from superset.utils.dashboard_filter_scopes_converter import copy_filter_scopes
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DashboardDAO(BaseDAO):
|
||||
model_cls = Dashboard
|
||||
class DashboardDAO(BaseDAO[Dashboard]):
|
||||
base_filter = DashboardAccessFilter
|
||||
|
||||
@classmethod
|
||||
|
|
@ -379,8 +378,7 @@ class DashboardDAO(BaseDAO):
|
|||
db.session.commit()
|
||||
|
||||
|
||||
class EmbeddedDashboardDAO(BaseDAO):
|
||||
model_cls = EmbeddedDashboard
|
||||
class EmbeddedDashboardDAO(BaseDAO[EmbeddedDashboard]):
|
||||
# There isn't really a regular scenario where we would rather get Embedded by id
|
||||
id_column_name = "uuid"
|
||||
|
||||
|
|
@ -407,9 +405,7 @@ class EmbeddedDashboardDAO(BaseDAO):
|
|||
raise NotImplementedError("Use EmbeddedDashboardDAO.upsert() instead.")
|
||||
|
||||
|
||||
class FilterSetDAO(BaseDAO):
|
||||
model_cls = FilterSet
|
||||
|
||||
class FilterSetDAO(BaseDAO[FilterSet]):
|
||||
@classmethod
|
||||
def create(cls, properties: dict[str, Any], commit: bool = True) -> Model:
|
||||
if cls.model_cls is None:
|
||||
|
|
|
|||
|
|
@ -31,8 +31,7 @@ from superset.utils.ssh_tunnel import unmask_password_info
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DatabaseDAO(BaseDAO):
|
||||
model_cls = Database
|
||||
class DatabaseDAO(BaseDAO[Database]):
|
||||
base_filter = DatabaseFilter
|
||||
|
||||
@classmethod
|
||||
|
|
@ -138,9 +137,7 @@ class DatabaseDAO(BaseDAO):
|
|||
return ssh_tunnel
|
||||
|
||||
|
||||
class SSHTunnelDAO(BaseDAO):
|
||||
model_cls = SSHTunnel
|
||||
|
||||
class SSHTunnelDAO(BaseDAO[SSHTunnel]):
|
||||
@classmethod
|
||||
def update(
|
||||
cls,
|
||||
|
|
|
|||
|
|
@ -31,8 +31,7 @@ from superset.views.base import DatasourceFilter
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods
|
||||
model_cls = SqlaTable
|
||||
class DatasetDAO(BaseDAO[SqlaTable]): # pylint: disable=too-many-public-methods
|
||||
base_filter = DatasourceFilter
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -151,7 +150,7 @@ class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods
|
|||
model: SqlaTable,
|
||||
properties: dict[str, Any],
|
||||
commit: bool = True,
|
||||
) -> Optional[SqlaTable]:
|
||||
) -> SqlaTable:
|
||||
"""
|
||||
Updates a Dataset model on the metadata DB
|
||||
"""
|
||||
|
|
@ -397,9 +396,9 @@ class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods
|
|||
)
|
||||
|
||||
|
||||
class DatasetColumnDAO(BaseDAO):
|
||||
model_cls = TableColumn
|
||||
class DatasetColumnDAO(BaseDAO[TableColumn]):
|
||||
pass
|
||||
|
||||
|
||||
class DatasetMetricDAO(BaseDAO):
|
||||
model_cls = SqlMetric
|
||||
class DatasetMetricDAO(BaseDAO[SqlMetric]):
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ logger = logging.getLogger(__name__)
|
|||
Datasource = Union[Dataset, SqlaTable, Table, Query, SavedQuery]
|
||||
|
||||
|
||||
class DatasourceDAO(BaseDAO):
|
||||
class DatasourceDAO(BaseDAO[Datasource]):
|
||||
sources: dict[Union[DatasourceType, str], type[Datasource]] = {
|
||||
DatasourceType.TABLE: SqlaTable,
|
||||
DatasourceType.QUERY: Query,
|
||||
|
|
|
|||
|
|
@ -30,9 +30,7 @@ from superset.utils.core import get_user_id
|
|||
from superset.utils.dates import datetime_to_epoch
|
||||
|
||||
|
||||
class LogDAO(BaseDAO):
|
||||
model_cls = Log
|
||||
|
||||
class LogDAO(BaseDAO[Log]):
|
||||
@staticmethod
|
||||
def get_recent_activity(
|
||||
actions: list[str],
|
||||
|
|
|
|||
|
|
@ -35,8 +35,7 @@ from superset.utils.dates import now_as_float
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class QueryDAO(BaseDAO):
|
||||
model_cls = Query
|
||||
class QueryDAO(BaseDAO[Query]):
|
||||
base_filter = QueryFilter
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -104,8 +103,7 @@ class QueryDAO(BaseDAO):
|
|||
db.session.commit()
|
||||
|
||||
|
||||
class SavedQueryDAO(BaseDAO):
|
||||
model_cls = SavedQuery
|
||||
class SavedQueryDAO(BaseDAO[SavedQuery]):
|
||||
base_filter = SavedQueryFilter
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -42,8 +42,7 @@ logger = logging.getLogger(__name__)
|
|||
REPORT_SCHEDULE_ERROR_NOTIFICATION_MARKER = "Notification sent with error"
|
||||
|
||||
|
||||
class ReportScheduleDAO(BaseDAO):
|
||||
model_cls = ReportSchedule
|
||||
class ReportScheduleDAO(BaseDAO[ReportSchedule]):
|
||||
base_filter = ReportScheduleFilter
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -19,5 +19,5 @@ from superset.connectors.sqla.models import RowLevelSecurityFilter
|
|||
from superset.daos.base import BaseDAO
|
||||
|
||||
|
||||
class RLSDAO(BaseDAO):
|
||||
model_cls = RowLevelSecurityFilter
|
||||
class RLSDAO(BaseDAO[RowLevelSecurityFilter]):
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -31,8 +31,7 @@ from superset.tags.models import get_tag, ObjectTypes, Tag, TaggedObject, TagTyp
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TagDAO(BaseDAO):
|
||||
model_cls = Tag
|
||||
class TagDAO(BaseDAO[Tag]):
|
||||
# base_filter = TagAccessFilter
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -44,6 +44,8 @@ class DeleteDashboardCommand(BaseCommand):
|
|||
|
||||
def run(self) -> Model:
|
||||
self.validate()
|
||||
assert self._model
|
||||
|
||||
try:
|
||||
dashboard = DashboardDAO.delete(self._model)
|
||||
except DAODeleteFailedError as ex:
|
||||
|
|
|
|||
|
|
@ -48,6 +48,8 @@ class UpdateDashboardCommand(UpdateMixin, BaseCommand):
|
|||
|
||||
def run(self) -> Model:
|
||||
self.validate()
|
||||
assert self._model
|
||||
|
||||
try:
|
||||
dashboard = DashboardDAO.update(self._model, self._properties, commit=False)
|
||||
if self._properties.get("json_metadata"):
|
||||
|
|
|
|||
|
|
@ -36,8 +36,10 @@ class DeleteFilterSetCommand(BaseFilterSetCommand):
|
|||
self._filter_set_id = filter_set_id
|
||||
|
||||
def run(self) -> Model:
|
||||
self.validate()
|
||||
assert self._filter_set
|
||||
|
||||
try:
|
||||
self.validate()
|
||||
return FilterSetDAO.delete(self._filter_set, commit=True)
|
||||
except DAODeleteFailedError as err:
|
||||
raise FilterSetDeleteFailedError(str(self._filter_set_id), "") from err
|
||||
|
|
|
|||
|
|
@ -39,6 +39,8 @@ class UpdateFilterSetCommand(BaseFilterSetCommand):
|
|||
def run(self) -> Model:
|
||||
try:
|
||||
self.validate()
|
||||
assert self._filter_set
|
||||
|
||||
if (
|
||||
OWNER_TYPE_FIELD in self._properties
|
||||
and self._properties[OWNER_TYPE_FIELD] == "Dashboard"
|
||||
|
|
|
|||
|
|
@ -42,6 +42,8 @@ class DeleteDatabaseCommand(BaseCommand):
|
|||
|
||||
def run(self) -> Model:
|
||||
self.validate()
|
||||
assert self._model
|
||||
|
||||
try:
|
||||
database = DatabaseDAO.delete(self._model)
|
||||
except DAODeleteFailedError as ex:
|
||||
|
|
|
|||
|
|
@ -42,6 +42,8 @@ class DeleteSSHTunnelCommand(BaseCommand):
|
|||
if not is_feature_enabled("SSH_TUNNELING"):
|
||||
raise SSHTunnelingNotEnabledError()
|
||||
self.validate()
|
||||
assert self._model
|
||||
|
||||
try:
|
||||
ssh_tunnel = SSHTunnelDAO.delete(self._model)
|
||||
except DAODeleteFailedError as ex:
|
||||
|
|
|
|||
|
|
@ -41,6 +41,8 @@ class DeleteReportScheduleCommand(BaseCommand):
|
|||
|
||||
def run(self) -> Model:
|
||||
self.validate()
|
||||
assert self._model
|
||||
|
||||
try:
|
||||
report_schedule = ReportScheduleDAO.delete(self._model)
|
||||
except DAODeleteFailedError as ex:
|
||||
|
|
|
|||
|
|
@ -41,6 +41,8 @@ class UpdateRLSRuleCommand(BaseCommand):
|
|||
|
||||
def run(self) -> Any:
|
||||
self.validate()
|
||||
assert self._model
|
||||
|
||||
try:
|
||||
rule = RLSDAO.update(self._model, self._properties)
|
||||
except DAOUpdateFailedError as ex:
|
||||
|
|
|
|||
Loading…
Reference in New Issue