fix(reports): make name unique between alerts and reports (#12196)
* fix(reports): make name unique between alerts and reports * add missing migration * make it work for mySQL and PG only (yet) * fixing sqlite crazy unique drop * fixing sqlite missing one col
This commit is contained in:
parent
b75a1ec71e
commit
74f3faf1cd
|
|
@ -0,0 +1,104 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""alert reports shared uniqueness
|
||||
|
||||
Revision ID: c878781977c6
|
||||
Revises: 73fd22e742ab
|
||||
Create Date: 2020-12-23 11:34:53.882200
|
||||
|
||||
"""
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "c878781977c6"
|
||||
down_revision = "73fd22e742ab"
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects.mysql.base import MySQLDialect
|
||||
from sqlalchemy.dialects.postgresql.base import PGDialect
|
||||
from sqlalchemy.dialects.sqlite.base import SQLiteDialect
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
|
||||
from superset.utils.core import generic_find_uq_constraint_name
|
||||
|
||||
report_schedule = sa.Table(
|
||||
"report_schedule",
|
||||
sa.MetaData(),
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("type", sa.String(length=50), nullable=False),
|
||||
sa.Column("name", sa.String(length=150), nullable=False),
|
||||
sa.Column("description", sa.Text(), nullable=True),
|
||||
sa.Column("context_markdown", sa.Text(), nullable=True),
|
||||
sa.Column("active", sa.Boolean(), default=True, nullable=True),
|
||||
sa.Column("crontab", sa.String(length=1000), nullable=False),
|
||||
sa.Column("sql", sa.Text(), nullable=True),
|
||||
sa.Column("chart_id", sa.Integer(), nullable=True),
|
||||
sa.Column("dashboard_id", sa.Integer(), nullable=True),
|
||||
sa.Column("database_id", sa.Integer(), nullable=True),
|
||||
sa.Column("last_eval_dttm", sa.DateTime(), nullable=True),
|
||||
sa.Column("last_state", sa.String(length=50), nullable=True),
|
||||
sa.Column("last_value", sa.Float(), nullable=True),
|
||||
sa.Column("last_value_row_json", sa.Text(), nullable=True),
|
||||
sa.Column("validator_type", sa.String(length=100), nullable=True),
|
||||
sa.Column("validator_config_json", sa.Text(), default="{}", nullable=True),
|
||||
sa.Column("log_retention", sa.Integer(), nullable=True, default=90),
|
||||
sa.Column("grace_period", sa.Integer(), nullable=True, default=60 * 60 * 4),
|
||||
sa.Column("working_timeout", sa.Integer(), nullable=True, default=60 * 60 * 1),
|
||||
# Audit Mixin
|
||||
sa.Column("created_on", sa.DateTime(), nullable=True),
|
||||
sa.Column("changed_on", sa.DateTime(), nullable=True),
|
||||
sa.Column("created_by_fk", sa.Integer(), nullable=True),
|
||||
sa.Column("changed_by_fk", sa.Integer(), nullable=True),
|
||||
sa.ForeignKeyConstraint(["chart_id"], ["slices.id"]),
|
||||
sa.ForeignKeyConstraint(["dashboard_id"], ["dashboards.id"]),
|
||||
sa.ForeignKeyConstraint(["database_id"], ["dbs.id"]),
|
||||
sa.ForeignKeyConstraint(["changed_by_fk"], ["ab_user.id"]),
|
||||
sa.ForeignKeyConstraint(["created_by_fk"], ["ab_user.id"]),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
|
||||
def upgrade():
|
||||
bind = op.get_bind()
|
||||
|
||||
if not isinstance(bind.dialect, SQLiteDialect):
|
||||
op.drop_constraint("uq_report_schedule_name", "report_schedule", type_="unique")
|
||||
|
||||
if isinstance(bind.dialect, MySQLDialect):
|
||||
op.drop_index(
|
||||
op.f("name"), table_name="report_schedule",
|
||||
)
|
||||
|
||||
if isinstance(bind.dialect, PGDialect):
|
||||
op.drop_constraint(
|
||||
"report_schedule_name_key", "report_schedule", type_="unique"
|
||||
)
|
||||
op.create_unique_constraint(
|
||||
"uq_report_schedule_name_type", "report_schedule", ["name", "type"]
|
||||
)
|
||||
|
||||
else:
|
||||
with op.batch_alter_table(
|
||||
"report_schedule", copy_from=report_schedule
|
||||
) as batch_op:
|
||||
batch_op.create_unique_constraint(
|
||||
"uq_report_schedule_name_type", ["name", "type"]
|
||||
)
|
||||
|
||||
|
||||
def downgrade():
|
||||
pass
|
||||
|
|
@ -92,9 +92,11 @@ class ReportSchedule(Model, AuditMixinNullable):
|
|||
"""
|
||||
|
||||
__tablename__ = "report_schedule"
|
||||
__table_args__ = (UniqueConstraint("name", "type"),)
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
type = Column(String(50), nullable=False)
|
||||
name = Column(String(150), nullable=False, unique=True)
|
||||
name = Column(String(150), nullable=False)
|
||||
description = Column(Text)
|
||||
context_markdown = Column(Text)
|
||||
active = Column(Boolean, default=True, index=True)
|
||||
|
|
|
|||
|
|
@ -64,8 +64,10 @@ class CreateReportScheduleCommand(BaseReportScheduleCommand):
|
|||
if not report_type:
|
||||
exceptions.append(ReportScheduleRequiredTypeValidationError())
|
||||
|
||||
# Validate name uniqueness
|
||||
if not ReportScheduleDAO.validate_update_uniqueness(name):
|
||||
# Validate name type uniqueness
|
||||
if report_type and not ReportScheduleDAO.validate_update_uniqueness(
|
||||
name, report_type
|
||||
):
|
||||
exceptions.append(ReportScheduleNameUniquenessValidationError())
|
||||
|
||||
# validate relation by report type
|
||||
|
|
|
|||
|
|
@ -126,7 +126,7 @@ class ReportScheduleWorkingTimeoutError(CommandException):
|
|||
|
||||
class ReportScheduleNameUniquenessValidationError(ValidationError):
|
||||
"""
|
||||
Marshmallow validation error for Report Schedule name already exists
|
||||
Marshmallow validation error for Report Schedule name and type already exists
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
|
|
|
|||
|
|
@ -80,15 +80,16 @@ class UpdateReportScheduleCommand(BaseReportScheduleCommand):
|
|||
):
|
||||
self._properties["last_state"] = ReportState.NOOP
|
||||
|
||||
# Validate name uniqueness
|
||||
if not ReportScheduleDAO.validate_update_uniqueness(
|
||||
name, report_schedule_id=self._model_id
|
||||
):
|
||||
exceptions.append(ReportScheduleNameUniquenessValidationError())
|
||||
|
||||
# validate relation by report type
|
||||
if not report_type:
|
||||
report_type = self._model.type
|
||||
|
||||
# Validate name type uniqueness
|
||||
if not ReportScheduleDAO.validate_update_uniqueness(
|
||||
name, report_type, report_schedule_id=self._model_id
|
||||
):
|
||||
exceptions.append(ReportScheduleNameUniquenessValidationError())
|
||||
|
||||
if report_type == ReportScheduleType.ALERT:
|
||||
database_id = self._properties.get("database")
|
||||
# If database_id was sent let's validate it exists
|
||||
|
|
|
|||
|
|
@ -111,17 +111,20 @@ class ReportScheduleDAO(BaseDAO):
|
|||
|
||||
@staticmethod
|
||||
def validate_update_uniqueness(
|
||||
name: str, report_schedule_id: Optional[int] = None
|
||||
name: str, report_type: str, report_schedule_id: Optional[int] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Validate if this name is unique.
|
||||
Validate if this name and type is unique.
|
||||
|
||||
:param name: The report schedule name
|
||||
:param report_type: The report schedule type
|
||||
:param report_schedule_id: The report schedule current id
|
||||
(only for validating on updates)
|
||||
:return: bool
|
||||
"""
|
||||
query = db.session.query(ReportSchedule).filter(ReportSchedule.name == name)
|
||||
query = db.session.query(ReportSchedule).filter(
|
||||
ReportSchedule.name == name, ReportSchedule.type == report_type
|
||||
)
|
||||
if report_schedule_id:
|
||||
query = query.filter(ReportSchedule.id != report_schedule_id)
|
||||
return not db.session.query(query.exists()).scalar()
|
||||
|
|
|
|||
|
|
@ -14,10 +14,10 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from typing import Union
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
from croniter import croniter
|
||||
from marshmallow import fields, Schema, validate
|
||||
from marshmallow import fields, Schema, validate, validates_schema
|
||||
from marshmallow.validate import Length, ValidationError
|
||||
|
||||
from superset.models.reports import (
|
||||
|
|
@ -170,6 +170,16 @@ class ReportSchedulePostSchema(Schema):
|
|||
|
||||
recipients = fields.List(fields.Nested(ReportRecipientSchema))
|
||||
|
||||
@validates_schema
|
||||
def validate_report_references( # pylint: disable=unused-argument,no-self-use
|
||||
self, data: Dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
if data["type"] == ReportScheduleType.REPORT:
|
||||
if "database" in data:
|
||||
raise ValidationError(
|
||||
{"database": ["Database reference is not allowed on a report"]}
|
||||
)
|
||||
|
||||
|
||||
class ReportSchedulePutSchema(Schema):
|
||||
type = fields.String(
|
||||
|
|
|
|||
|
|
@ -468,6 +468,46 @@ class TestReportSchedulesApi(SupersetTestCase):
|
|||
data = json.loads(rv.data.decode("utf-8"))
|
||||
assert data == {"message": {"name": ["Name must be unique"]}}
|
||||
|
||||
# Check that uniqueness is composed by name and type
|
||||
report_schedule_data = {
|
||||
"type": ReportScheduleType.REPORT,
|
||||
"name": "name3",
|
||||
"description": "description",
|
||||
"crontab": "0 9 * * *",
|
||||
"chart": chart.id,
|
||||
}
|
||||
uri = "api/v1/report/"
|
||||
rv = self.client.post(uri, json=report_schedule_data)
|
||||
assert rv.status_code == 201
|
||||
data = json.loads(rv.data.decode("utf-8"))
|
||||
|
||||
# Rollback changes
|
||||
created_model = db.session.query(ReportSchedule).get(data.get("id"))
|
||||
db.session.delete(created_model)
|
||||
db.session.commit()
|
||||
|
||||
@pytest.mark.usefixtures("create_report_schedules")
|
||||
def test_create_report_schedule_schema(self):
|
||||
"""
|
||||
ReportSchedule Api: Test create report schedule schema check
|
||||
"""
|
||||
self.login(username="admin")
|
||||
chart = db.session.query(Slice).first()
|
||||
example_db = get_example_database()
|
||||
|
||||
# Check that a report does not have a database reference
|
||||
report_schedule_data = {
|
||||
"type": ReportScheduleType.REPORT,
|
||||
"name": "name3",
|
||||
"description": "description",
|
||||
"crontab": "0 9 * * *",
|
||||
"chart": chart.id,
|
||||
"database": example_db.id,
|
||||
}
|
||||
uri = "api/v1/report/"
|
||||
rv = self.client.post(uri, json=report_schedule_data)
|
||||
assert rv.status_code == 400
|
||||
|
||||
@pytest.mark.usefixtures("create_report_schedules")
|
||||
def test_create_report_schedule_chart_dash_validation(self):
|
||||
"""
|
||||
|
|
|
|||
Loading…
Reference in New Issue