feat: new reports scheduler (#11711)
* feat(reports): scheduler and delivery system * working version * improvements and fix grace_period * add tests and fix bugs * fix report API test * test MySQL test fail * delete-orphans * fix MySQL tests * address comments * lint
This commit is contained in:
parent
501b9d47c5
commit
f27ebc4be5
|
|
@ -20,6 +20,7 @@ from flask_appbuilder.models.filters import BaseFilter
|
|||
from flask_appbuilder.models.sqla import Model
|
||||
from flask_appbuilder.models.sqla.interface import SQLAInterface
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from superset.dao.exceptions import (
|
||||
DAOConfigError,
|
||||
|
|
@ -46,13 +47,14 @@ class BaseDAO:
|
|||
"""
|
||||
|
||||
@classmethod
|
||||
def find_by_id(cls, model_id: int) -> Model:
|
||||
def find_by_id(cls, model_id: int, session: Session = None) -> Model:
|
||||
"""
|
||||
Find a model by id, if defined applies `base_filter`
|
||||
"""
|
||||
query = db.session.query(cls.model_cls)
|
||||
session = session or db.session
|
||||
query = session.query(cls.model_cls)
|
||||
if cls.base_filter:
|
||||
data_model = SQLAInterface(cls.model_cls, db.session)
|
||||
data_model = SQLAInterface(cls.model_cls, session)
|
||||
query = cls.base_filter( # pylint: disable=not-callable
|
||||
"id", data_model
|
||||
).apply(query, None)
|
||||
|
|
|
|||
|
|
@ -25,7 +25,9 @@ class SupersetException(Exception):
|
|||
status = 500
|
||||
message = ""
|
||||
|
||||
def __init__(self, message: str = "", exception: Optional[Exception] = None):
|
||||
def __init__(
|
||||
self, message: str = "", exception: Optional[Exception] = None,
|
||||
) -> None:
|
||||
if message:
|
||||
self.message = message
|
||||
self._exception = exception
|
||||
|
|
|
|||
|
|
@ -60,7 +60,9 @@ class ReportRecipientType(str, enum.Enum):
|
|||
|
||||
class ReportLogState(str, enum.Enum):
|
||||
SUCCESS = "Success"
|
||||
WORKING = "Working"
|
||||
ERROR = "Error"
|
||||
NOOP = "Not triggered"
|
||||
|
||||
|
||||
class ReportEmailFormat(str, enum.Enum):
|
||||
|
|
@ -175,6 +177,6 @@ class ReportExecutionLog(Model): # pylint: disable=too-few-public-methods
|
|||
)
|
||||
report_schedule = relationship(
|
||||
ReportSchedule,
|
||||
backref=backref("logs", cascade="all,delete"),
|
||||
backref=backref("logs", cascade="all,delete,delete-orphan"),
|
||||
foreign_keys=[report_schedule_id],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,101 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import json
|
||||
import logging
|
||||
from operator import eq, ge, gt, le, lt, ne
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
from flask_babel import lazy_gettext as _
|
||||
|
||||
from superset import jinja_context
|
||||
from superset.commands.base import BaseCommand
|
||||
from superset.models.reports import ReportSchedule, ReportScheduleValidatorType
|
||||
from superset.reports.commands.exceptions import (
|
||||
AlertQueryInvalidTypeError,
|
||||
AlertQueryMultipleColumnsError,
|
||||
AlertQueryMultipleRowsError,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
OPERATOR_FUNCTIONS = {">=": ge, ">": gt, "<=": le, "<": lt, "==": eq, "!=": ne}
|
||||
|
||||
|
||||
class AlertCommand(BaseCommand):
|
||||
def __init__(self, report_schedule: ReportSchedule):
|
||||
self._report_schedule = report_schedule
|
||||
self._result: Optional[float] = None
|
||||
|
||||
def run(self) -> bool:
|
||||
self.validate()
|
||||
|
||||
if self._report_schedule.validator_type == ReportScheduleValidatorType.NOT_NULL:
|
||||
self._report_schedule.last_value_row_json = self._result
|
||||
return self._result not in (0, None, np.nan)
|
||||
self._report_schedule.last_value = self._result
|
||||
operator = json.loads(self._report_schedule.validator_config_json)["op"]
|
||||
threshold = json.loads(self._report_schedule.validator_config_json)["threshold"]
|
||||
return OPERATOR_FUNCTIONS[operator](self._result, threshold)
|
||||
|
||||
def _validate_not_null(self, rows: np.recarray) -> None:
|
||||
self._result = rows[0][1]
|
||||
|
||||
def _validate_operator(self, rows: np.recarray) -> None:
|
||||
# check if query return more then one row
|
||||
if len(rows) > 1:
|
||||
raise AlertQueryMultipleRowsError(
|
||||
message=_(
|
||||
"Alert query returned more then one row. %s rows returned"
|
||||
% len(rows),
|
||||
)
|
||||
)
|
||||
# check if query returned more then one column
|
||||
if len(rows[0]) > 2:
|
||||
raise AlertQueryMultipleColumnsError(
|
||||
_(
|
||||
"Alert query returned more then one column. %s columns returned"
|
||||
% len(rows[0])
|
||||
)
|
||||
)
|
||||
if rows[0][1] is None:
|
||||
return
|
||||
try:
|
||||
# Check if it's float or if we can convert it
|
||||
self._result = float(rows[0][1])
|
||||
return
|
||||
except (AssertionError, TypeError, ValueError):
|
||||
raise AlertQueryInvalidTypeError()
|
||||
|
||||
def validate(self) -> None:
|
||||
"""
|
||||
Validate the query result as a Pandas DataFrame
|
||||
"""
|
||||
sql_template = jinja_context.get_template_processor(
|
||||
database=self._report_schedule.database
|
||||
)
|
||||
rendered_sql = sql_template.process_template(self._report_schedule.sql)
|
||||
df = self._report_schedule.database.get_df(rendered_sql)
|
||||
|
||||
if df.empty:
|
||||
return
|
||||
rows = df.to_records()
|
||||
if self._report_schedule.validator_type == ReportScheduleValidatorType.NOT_NULL:
|
||||
self._validate_not_null(rows)
|
||||
return
|
||||
self._validate_operator(rows)
|
||||
|
|
@ -103,6 +103,22 @@ class ReportScheduleDeleteFailedError(CommandException):
|
|||
message = _("Report Schedule delete failed.")
|
||||
|
||||
|
||||
class PruneReportScheduleLogFailedError(CommandException):
|
||||
message = _("Report Schedule log prune failed.")
|
||||
|
||||
|
||||
class ReportScheduleScreenshotFailedError(CommandException):
|
||||
message = _("Report Schedule execution failed when generating a screenshot.")
|
||||
|
||||
|
||||
class ReportScheduleExecuteUnexpectedError(CommandException):
|
||||
message = _("Report Schedule execution got an unexpected error.")
|
||||
|
||||
|
||||
class ReportSchedulePreviousWorkingError(CommandException):
|
||||
message = _("Report Schedule is still working, refusing to re-compute.")
|
||||
|
||||
|
||||
class ReportScheduleNameUniquenessValidationError(ValidationError):
|
||||
"""
|
||||
Marshmallow validation error for Report Schedule name already exists
|
||||
|
|
@ -110,3 +126,24 @@ class ReportScheduleNameUniquenessValidationError(ValidationError):
|
|||
|
||||
def __init__(self) -> None:
|
||||
super().__init__([_("Name must be unique")], field_name="name")
|
||||
|
||||
|
||||
class AlertQueryMultipleRowsError(CommandException):
|
||||
|
||||
message = _("Alert query returned more then one row.")
|
||||
|
||||
|
||||
class AlertQueryMultipleColumnsError(CommandException):
|
||||
message = _("Alert query returned more then one column.")
|
||||
|
||||
|
||||
class AlertQueryInvalidTypeError(CommandException):
|
||||
message = _("Alert query returned a non-number value.")
|
||||
|
||||
|
||||
class ReportScheduleAlertGracePeriodError(CommandException):
|
||||
message = _("Alert fired during grace period.")
|
||||
|
||||
|
||||
class ReportScheduleNotificationError(CommandException):
|
||||
message = _("Alert on grace period")
|
||||
|
|
|
|||
|
|
@ -0,0 +1,256 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from superset import app, thumbnail_cache
|
||||
from superset.commands.base import BaseCommand
|
||||
from superset.commands.exceptions import CommandException
|
||||
from superset.extensions import security_manager
|
||||
from superset.models.reports import (
|
||||
ReportExecutionLog,
|
||||
ReportLogState,
|
||||
ReportSchedule,
|
||||
ReportScheduleType,
|
||||
)
|
||||
from superset.reports.commands.alert import AlertCommand
|
||||
from superset.reports.commands.exceptions import (
|
||||
ReportScheduleAlertGracePeriodError,
|
||||
ReportScheduleExecuteUnexpectedError,
|
||||
ReportScheduleNotFoundError,
|
||||
ReportScheduleNotificationError,
|
||||
ReportSchedulePreviousWorkingError,
|
||||
ReportScheduleScreenshotFailedError,
|
||||
)
|
||||
from superset.reports.dao import ReportScheduleDAO
|
||||
from superset.reports.notifications import create_notification
|
||||
from superset.reports.notifications.base import NotificationContent, ScreenshotData
|
||||
from superset.reports.notifications.exceptions import NotificationError
|
||||
from superset.utils.celery import session_scope
|
||||
from superset.utils.screenshots import (
|
||||
BaseScreenshot,
|
||||
ChartScreenshot,
|
||||
DashboardScreenshot,
|
||||
)
|
||||
from superset.utils.urls import get_url_path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AsyncExecuteReportScheduleCommand(BaseCommand):
|
||||
"""
|
||||
Execute all types of report schedules.
|
||||
- On reports takes chart or dashboard screenshots and sends configured notifications
|
||||
- On Alerts uses related Command AlertCommand and sends configured notifications
|
||||
"""
|
||||
|
||||
def __init__(self, model_id: int, scheduled_dttm: datetime):
|
||||
self._model_id = model_id
|
||||
self._model: Optional[ReportSchedule] = None
|
||||
self._scheduled_dttm = scheduled_dttm
|
||||
|
||||
def set_state_and_log(
|
||||
self,
|
||||
session: Session,
|
||||
start_dttm: datetime,
|
||||
state: ReportLogState,
|
||||
error_message: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Updates current ReportSchedule state and TS. If on final state writes the log
|
||||
for this execution
|
||||
"""
|
||||
now_dttm = datetime.utcnow()
|
||||
if state == ReportLogState.WORKING:
|
||||
self.set_state(session, state, now_dttm)
|
||||
return
|
||||
self.set_state(session, state, now_dttm)
|
||||
self.create_log(
|
||||
session, start_dttm, now_dttm, state, error_message=error_message,
|
||||
)
|
||||
|
||||
def set_state(
|
||||
self, session: Session, state: ReportLogState, dttm: datetime
|
||||
) -> None:
|
||||
"""
|
||||
Set the current report schedule state, on this case we want to
|
||||
commit immediately
|
||||
"""
|
||||
if self._model:
|
||||
self._model.last_state = state
|
||||
self._model.last_eval_dttm = dttm
|
||||
session.commit()
|
||||
|
||||
def create_log( # pylint: disable=too-many-arguments
|
||||
self,
|
||||
session: Session,
|
||||
start_dttm: datetime,
|
||||
end_dttm: datetime,
|
||||
state: ReportLogState,
|
||||
error_message: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Creates a Report execution log, uses the current computed last_value for Alerts
|
||||
"""
|
||||
if self._model:
|
||||
log = ReportExecutionLog(
|
||||
scheduled_dttm=self._scheduled_dttm,
|
||||
start_dttm=start_dttm,
|
||||
end_dttm=end_dttm,
|
||||
value=self._model.last_value,
|
||||
value_row_json=self._model.last_value_row_json,
|
||||
state=state,
|
||||
error_message=error_message,
|
||||
report_schedule=self._model,
|
||||
)
|
||||
session.add(log)
|
||||
|
||||
@staticmethod
|
||||
def _get_url(report_schedule: ReportSchedule, user_friendly: bool = False) -> str:
|
||||
"""
|
||||
Get the url for this report schedule: chart or dashboard
|
||||
"""
|
||||
if report_schedule.chart:
|
||||
return get_url_path(
|
||||
"Superset.slice",
|
||||
user_friendly=user_friendly,
|
||||
slice_id=report_schedule.chart_id,
|
||||
standalone="true",
|
||||
)
|
||||
return get_url_path(
|
||||
"Superset.dashboard",
|
||||
user_friendly=user_friendly,
|
||||
dashboard_id_or_slug=report_schedule.dashboard_id,
|
||||
)
|
||||
|
||||
def _get_screenshot(self, report_schedule: ReportSchedule) -> ScreenshotData:
|
||||
"""
|
||||
Get a chart or dashboard screenshot
|
||||
:raises: ReportScheduleScreenshotFailedError
|
||||
"""
|
||||
url = self._get_url(report_schedule)
|
||||
screenshot: Optional[BaseScreenshot] = None
|
||||
if report_schedule.chart:
|
||||
screenshot = ChartScreenshot(url, report_schedule.chart.digest)
|
||||
else:
|
||||
screenshot = DashboardScreenshot(url, report_schedule.dashboard.digest)
|
||||
image_url = self._get_url(report_schedule, user_friendly=True)
|
||||
user = security_manager.find_user(app.config["THUMBNAIL_SELENIUM_USER"])
|
||||
image_data = screenshot.compute_and_cache(
|
||||
user=user, cache=thumbnail_cache, force=True,
|
||||
)
|
||||
if not image_data:
|
||||
raise ReportScheduleScreenshotFailedError()
|
||||
return ScreenshotData(url=image_url, image=image_data)
|
||||
|
||||
def _get_notification_content(
|
||||
self, report_schedule: ReportSchedule
|
||||
) -> NotificationContent:
|
||||
"""
|
||||
Gets a notification content, this is composed by a title and a screenshot
|
||||
:raises: ReportScheduleScreenshotFailedError
|
||||
"""
|
||||
screenshot_data = self._get_screenshot(report_schedule)
|
||||
if report_schedule.chart:
|
||||
name = report_schedule.chart.slice_name
|
||||
else:
|
||||
name = report_schedule.dashboard.dashboard_title
|
||||
return NotificationContent(name=name, screenshot=screenshot_data)
|
||||
|
||||
def _send(self, report_schedule: ReportSchedule) -> None:
|
||||
"""
|
||||
Creates the notification content and sends them to all recipients
|
||||
|
||||
:raises: ReportScheduleNotificationError
|
||||
"""
|
||||
notification_errors = []
|
||||
notification_content = self._get_notification_content(report_schedule)
|
||||
for recipient in report_schedule.recipients:
|
||||
notification = create_notification(recipient, notification_content)
|
||||
try:
|
||||
notification.send()
|
||||
except NotificationError as ex:
|
||||
# collect notification errors but keep processing them
|
||||
notification_errors.append(str(ex))
|
||||
if notification_errors:
|
||||
raise ReportScheduleNotificationError(";".join(notification_errors))
|
||||
|
||||
def run(self) -> None:
|
||||
with session_scope(nullpool=True) as session:
|
||||
try:
|
||||
start_dttm = datetime.utcnow()
|
||||
self.validate(session=session)
|
||||
if not self._model:
|
||||
raise ReportScheduleExecuteUnexpectedError()
|
||||
self.set_state_and_log(session, start_dttm, ReportLogState.WORKING)
|
||||
# If it's an alert check if the alert is triggered
|
||||
if self._model.type == ReportScheduleType.ALERT:
|
||||
if not AlertCommand(self._model).run():
|
||||
self.set_state_and_log(session, start_dttm, ReportLogState.NOOP)
|
||||
return
|
||||
|
||||
self._send(self._model)
|
||||
|
||||
# Log, state and TS
|
||||
self.set_state_and_log(session, start_dttm, ReportLogState.SUCCESS)
|
||||
except ReportScheduleAlertGracePeriodError as ex:
|
||||
self.set_state_and_log(
|
||||
session, start_dttm, ReportLogState.NOOP, error_message=str(ex)
|
||||
)
|
||||
except ReportSchedulePreviousWorkingError as ex:
|
||||
self.create_log(
|
||||
session,
|
||||
start_dttm,
|
||||
datetime.utcnow(),
|
||||
state=ReportLogState.ERROR,
|
||||
error_message=str(ex),
|
||||
)
|
||||
session.commit()
|
||||
raise
|
||||
except CommandException as ex:
|
||||
self.set_state_and_log(
|
||||
session, start_dttm, ReportLogState.ERROR, error_message=str(ex)
|
||||
)
|
||||
# We want to actually commit the state and log inside the scope
|
||||
session.commit()
|
||||
raise
|
||||
|
||||
def validate( # pylint: disable=arguments-differ
|
||||
self, session: Session = None
|
||||
) -> None:
|
||||
# Validate/populate model exists
|
||||
self._model = ReportScheduleDAO.find_by_id(self._model_id, session=session)
|
||||
if not self._model:
|
||||
raise ReportScheduleNotFoundError()
|
||||
# Avoid overlap processing
|
||||
if self._model.last_state == ReportLogState.WORKING:
|
||||
raise ReportSchedulePreviousWorkingError()
|
||||
# Check grace period
|
||||
if self._model.type == ReportScheduleType.ALERT:
|
||||
last_success = ReportScheduleDAO.find_last_success_log(session)
|
||||
if (
|
||||
last_success
|
||||
and self._model.last_state
|
||||
in (ReportLogState.SUCCESS, ReportLogState.NOOP)
|
||||
and self._model.grace_period
|
||||
and datetime.utcnow() - timedelta(seconds=self._model.grace_period)
|
||||
< last_success.end_dttm
|
||||
):
|
||||
raise ReportScheduleAlertGracePeriodError()
|
||||
|
|
@ -0,0 +1,48 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from superset.commands.base import BaseCommand
|
||||
from superset.models.reports import ReportSchedule
|
||||
from superset.reports.dao import ReportScheduleDAO
|
||||
from superset.utils.celery import session_scope
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AsyncPruneReportScheduleLogCommand(BaseCommand):
|
||||
"""
|
||||
Prunes logs from all report schedules
|
||||
"""
|
||||
|
||||
def __init__(self, worker_context: bool = True):
|
||||
self._worker_context = worker_context
|
||||
|
||||
def run(self) -> None:
|
||||
with session_scope(nullpool=True) as session:
|
||||
self.validate()
|
||||
for report_schedule in session.query(ReportSchedule).all():
|
||||
from_date = datetime.utcnow() - timedelta(
|
||||
days=report_schedule.log_retention
|
||||
)
|
||||
ReportScheduleDAO.bulk_delete_logs(
|
||||
report_schedule, from_date, session=session, commit=False
|
||||
)
|
||||
|
||||
def validate(self) -> None:
|
||||
pass
|
||||
|
|
@ -15,15 +15,22 @@
|
|||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from flask_appbuilder import Model
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from superset.dao.base import BaseDAO
|
||||
from superset.dao.exceptions import DAOCreateFailedError, DAODeleteFailedError
|
||||
from superset.extensions import db
|
||||
from superset.models.reports import ReportRecipients, ReportSchedule
|
||||
from superset.models.reports import (
|
||||
ReportExecutionLog,
|
||||
ReportLogState,
|
||||
ReportRecipients,
|
||||
ReportSchedule,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -135,3 +142,49 @@ class ReportScheduleDAO(BaseDAO):
|
|||
except SQLAlchemyError:
|
||||
db.session.rollback()
|
||||
raise DAOCreateFailedError
|
||||
|
||||
@staticmethod
|
||||
def find_active(session: Optional[Session] = None) -> List[ReportSchedule]:
|
||||
"""
|
||||
Find all active reports. If session is passed it will be used instead of the
|
||||
default `db.session`, this is useful when on a celery worker session context
|
||||
"""
|
||||
session = session or db.session
|
||||
return (
|
||||
session.query(ReportSchedule).filter(ReportSchedule.active.is_(True)).all()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def find_last_success_log(
|
||||
session: Optional[Session] = None,
|
||||
) -> Optional[ReportExecutionLog]:
|
||||
"""
|
||||
Finds last success execution log
|
||||
"""
|
||||
session = session or db.session
|
||||
return (
|
||||
session.query(ReportExecutionLog)
|
||||
.filter(ReportExecutionLog.state == ReportLogState.SUCCESS)
|
||||
.order_by(ReportExecutionLog.end_dttm.desc())
|
||||
.first()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def bulk_delete_logs(
|
||||
model: ReportSchedule,
|
||||
from_date: datetime,
|
||||
session: Optional[Session] = None,
|
||||
commit: bool = True,
|
||||
) -> None:
|
||||
session = session or db.session
|
||||
try:
|
||||
session.query(ReportExecutionLog).filter(
|
||||
ReportExecutionLog.report_schedule == model,
|
||||
ReportExecutionLog.end_dttm < from_date,
|
||||
).delete(synchronize_session="fetch")
|
||||
if commit:
|
||||
session.commit()
|
||||
except SQLAlchemyError as ex:
|
||||
if commit:
|
||||
session.rollback()
|
||||
raise ex
|
||||
|
|
|
|||
|
|
@ -0,0 +1,34 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from superset.models.reports import ReportRecipients
|
||||
from superset.reports.notifications.base import BaseNotification, NotificationContent
|
||||
from superset.reports.notifications.email import EmailNotification
|
||||
from superset.reports.notifications.slack import SlackNotification
|
||||
|
||||
|
||||
def create_notification(
|
||||
recipient: ReportRecipients, screenshot_data: NotificationContent
|
||||
) -> BaseNotification:
|
||||
"""
|
||||
Notification polymorphic factory
|
||||
Returns the Notification class for the recipient type
|
||||
"""
|
||||
for plugin in BaseNotification.plugins:
|
||||
if plugin.type == recipient.type:
|
||||
return plugin(recipient, screenshot_data)
|
||||
raise Exception("Recipient type not supported")
|
||||
|
|
@ -0,0 +1,62 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, List, Optional, Type
|
||||
|
||||
from superset.models.reports import ReportRecipients, ReportRecipientType
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScreenshotData:
|
||||
url: str # url to chart/dashboard for this screenshot
|
||||
image: bytes # bytes for the screenshot
|
||||
|
||||
|
||||
@dataclass
|
||||
class NotificationContent:
|
||||
name: str
|
||||
screenshot: ScreenshotData
|
||||
|
||||
|
||||
class BaseNotification: # pylint: disable=too-few-public-methods
|
||||
"""
|
||||
Serves has base for all notifications and creates a simple plugin system
|
||||
for extending future implementations.
|
||||
Child implementations get automatically registered and should identify the
|
||||
notification type
|
||||
"""
|
||||
|
||||
plugins: List[Type["BaseNotification"]] = []
|
||||
type: Optional[ReportRecipientType] = None
|
||||
"""
|
||||
Child classes set their notification type ex: `type = "email"` this string will be
|
||||
used by ReportRecipients.type to map to the correct implementation
|
||||
"""
|
||||
|
||||
def __init_subclass__(cls, *args: Any, **kwargs: Any) -> None:
|
||||
super().__init_subclass__(*args, **kwargs) # type: ignore
|
||||
cls.plugins.append(cls)
|
||||
|
||||
def __init__(
|
||||
self, recipient: ReportRecipients, content: NotificationContent
|
||||
) -> None:
|
||||
self._recipient = recipient
|
||||
self._content = content
|
||||
|
||||
def send(self) -> None:
|
||||
raise NotImplementedError()
|
||||
|
|
@ -0,0 +1,98 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from email.utils import make_msgid, parseaddr
|
||||
from typing import Dict
|
||||
|
||||
from flask_babel import gettext as __
|
||||
|
||||
from superset import app
|
||||
from superset.models.reports import ReportRecipientType
|
||||
from superset.reports.notifications.base import BaseNotification
|
||||
from superset.reports.notifications.exceptions import NotificationError
|
||||
from superset.utils.core import send_email_smtp
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmailContent:
|
||||
body: str
|
||||
images: Dict[str, bytes]
|
||||
|
||||
|
||||
class EmailNotification(BaseNotification): # pylint: disable=too-few-public-methods
|
||||
"""
|
||||
Sends an email notification for a report recipient
|
||||
"""
|
||||
|
||||
type = ReportRecipientType.EMAIL
|
||||
|
||||
@staticmethod
|
||||
def _get_smtp_domain() -> str:
|
||||
return parseaddr(app.config["SMTP_MAIL_FROM"])[1].split("@")[1]
|
||||
|
||||
def _get_content(self) -> EmailContent:
|
||||
# Get the domain from the 'From' address ..
|
||||
# and make a message id without the < > in the ends
|
||||
domain = self._get_smtp_domain()
|
||||
msgid = make_msgid(domain)[1:-1]
|
||||
|
||||
image = {msgid: self._content.screenshot.image}
|
||||
body = __(
|
||||
"""
|
||||
<b><a href="%(url)s">Explore in Superset</a></b><p></p>
|
||||
<img src="cid:%(msgid)s">
|
||||
""",
|
||||
url=self._content.screenshot.url,
|
||||
msgid=msgid,
|
||||
)
|
||||
return EmailContent(body=body, images=image)
|
||||
|
||||
def _get_subject(self) -> str:
|
||||
return __(
|
||||
"%(prefix)s %(title)s",
|
||||
prefix=app.config["EMAIL_REPORTS_SUBJECT_PREFIX"],
|
||||
title=self._content.name,
|
||||
)
|
||||
|
||||
def _get_to(self) -> str:
|
||||
return json.loads(self._recipient.recipient_config_json)["target"]
|
||||
|
||||
def send(self) -> None:
|
||||
subject = self._get_subject()
|
||||
content = self._get_content()
|
||||
to = self._get_to()
|
||||
try:
|
||||
send_email_smtp(
|
||||
to,
|
||||
subject,
|
||||
content.body,
|
||||
app.config,
|
||||
files=[],
|
||||
data=None,
|
||||
images=content.images,
|
||||
bcc="",
|
||||
mime_subtype="related",
|
||||
dryrun=False,
|
||||
)
|
||||
logger.info("Report sent to email")
|
||||
except Exception as ex:
|
||||
raise NotificationError(ex)
|
||||
|
|
@ -0,0 +1,20 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
|
||||
class NotificationError(Exception):
|
||||
pass
|
||||
|
|
@ -0,0 +1,89 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import json
|
||||
import logging
|
||||
from io import IOBase
|
||||
from typing import cast, Optional, Union
|
||||
|
||||
from flask_babel import gettext as __
|
||||
from retry.api import retry
|
||||
from slack import WebClient
|
||||
from slack.errors import SlackApiError, SlackClientError
|
||||
from slack.web.slack_response import SlackResponse
|
||||
|
||||
from superset import app
|
||||
from superset.models.reports import ReportRecipientType
|
||||
from superset.reports.notifications.base import BaseNotification
|
||||
from superset.reports.notifications.exceptions import NotificationError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SlackNotification(BaseNotification): # pylint: disable=too-few-public-methods
|
||||
"""
|
||||
Sends a slack notification for a report recipient
|
||||
"""
|
||||
|
||||
type = ReportRecipientType.SLACK
|
||||
|
||||
def _get_channel(self) -> str:
|
||||
return json.loads(self._recipient.recipient_config_json)["target"]
|
||||
|
||||
def _get_body(self) -> str:
|
||||
return __(
|
||||
"""
|
||||
*%(name)s*\n
|
||||
<%(url)s|Explore in Superset>
|
||||
""",
|
||||
name=self._content.name,
|
||||
url=self._content.screenshot.url,
|
||||
)
|
||||
|
||||
def _get_inline_screenshot(self) -> Optional[Union[str, IOBase, bytes]]:
|
||||
return self._content.screenshot.image
|
||||
|
||||
@retry(SlackApiError, delay=10, backoff=2, tries=5)
|
||||
def send(self) -> None:
|
||||
file = self._get_inline_screenshot()
|
||||
channel = self._get_channel()
|
||||
body = self._get_body()
|
||||
|
||||
try:
|
||||
client = WebClient(
|
||||
token=app.config["SLACK_API_TOKEN"], proxy=app.config["SLACK_PROXY"]
|
||||
)
|
||||
# files_upload returns SlackResponse as we run it in sync mode.
|
||||
if file:
|
||||
response = cast(
|
||||
SlackResponse,
|
||||
client.files_upload(
|
||||
channels=channel,
|
||||
file=file,
|
||||
initial_comment=body,
|
||||
title="subject",
|
||||
),
|
||||
)
|
||||
assert response["file"], str(response) # the uploaded file
|
||||
else:
|
||||
response = cast(
|
||||
SlackResponse, client.chat_postMessage(channel=channel, text=body),
|
||||
)
|
||||
assert response["message"]["text"], str(response)
|
||||
logger.info("Report sent to slack")
|
||||
except SlackClientError as ex:
|
||||
raise NotificationError(ex)
|
||||
|
|
@ -29,7 +29,7 @@ create_app()
|
|||
|
||||
# Need to import late, as the celery_app will have been setup by "create_app()"
|
||||
# pylint: disable=wrong-import-position, unused-import
|
||||
from . import cache, schedules # isort:skip
|
||||
from . import cache, schedules, scheduler # isort:skip
|
||||
|
||||
# Export the celery app globally for Celery (as run on the cmd line) to find
|
||||
app = celery_app
|
||||
|
|
|
|||
|
|
@ -0,0 +1,69 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Iterator
|
||||
|
||||
import croniter
|
||||
|
||||
from superset.commands.exceptions import CommandException
|
||||
from superset.extensions import celery_app
|
||||
from superset.reports.commands.execute import AsyncExecuteReportScheduleCommand
|
||||
from superset.reports.commands.log_prune import AsyncPruneReportScheduleLogCommand
|
||||
from superset.reports.dao import ReportScheduleDAO
|
||||
from superset.utils.celery import session_scope
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def cron_schedule_window(cron: str, window_size: int = 10) -> Iterator[datetime]:
|
||||
utc_now = datetime.utcnow()
|
||||
start_at = utc_now - timedelta(seconds=1)
|
||||
stop_at = utc_now + timedelta(seconds=window_size)
|
||||
crons = croniter.croniter(cron, start_at)
|
||||
for schedule in crons.all_next(datetime):
|
||||
if schedule >= stop_at:
|
||||
break
|
||||
yield schedule
|
||||
|
||||
|
||||
@celery_app.task(name="reports.scheduler")
|
||||
def scheduler() -> None:
|
||||
"""
|
||||
Celery beat main scheduler for reports
|
||||
"""
|
||||
with session_scope(nullpool=True) as session:
|
||||
active_schedules = ReportScheduleDAO.find_active(session)
|
||||
for active_schedule in active_schedules:
|
||||
for schedule in cron_schedule_window(active_schedule.crontab):
|
||||
execute.apply_async((active_schedule.id, schedule,), eta=schedule)
|
||||
|
||||
|
||||
@celery_app.task(name="reports.execute")
|
||||
def execute(report_schedule_id: int, scheduled_dttm: datetime) -> None:
|
||||
try:
|
||||
AsyncExecuteReportScheduleCommand(report_schedule_id, scheduled_dttm).run()
|
||||
except CommandException as ex:
|
||||
logger.error("An exception occurred while executing the report: %s", ex)
|
||||
|
||||
|
||||
@celery_app.task(name="reports.prune_log")
|
||||
def prune_log() -> None:
|
||||
try:
|
||||
AsyncPruneReportScheduleLogCommand().run()
|
||||
except CommandException as ex:
|
||||
logger.error("An exception occurred while pruning report schedule logs: %s", ex)
|
||||
|
|
@ -20,11 +20,15 @@ from typing import Any
|
|||
from flask import current_app, url_for
|
||||
|
||||
|
||||
def headless_url(path: str) -> str:
|
||||
base_url = current_app.config.get("WEBDRIVER_BASEURL", "")
|
||||
def headless_url(path: str, user_friendly: bool = False) -> str:
|
||||
base_url = (
|
||||
current_app.config["WEBDRIVER_BASEURL_USER_FRIENDLY"]
|
||||
if user_friendly
|
||||
else current_app.config["WEBDRIVER_BASEURL"]
|
||||
)
|
||||
return urllib.parse.urljoin(base_url, path)
|
||||
|
||||
|
||||
def get_url_path(view: str, **kwargs: Any) -> str:
|
||||
def get_url_path(view: str, user_friendly: bool = False, **kwargs: Any) -> str:
|
||||
with current_app.test_request_context():
|
||||
return headless_url(url_for(view, **kwargs))
|
||||
return headless_url(url_for(view, **kwargs), user_friendly=user_friendly)
|
||||
|
|
|
|||
|
|
@ -40,6 +40,7 @@ from superset.models.reports import (
|
|||
)
|
||||
|
||||
from tests.base_tests import SupersetTestCase
|
||||
from tests.reports.utils import insert_report_schedule
|
||||
from superset.utils.core import get_example_database
|
||||
|
||||
|
||||
|
|
@ -47,48 +48,6 @@ REPORTS_COUNT = 10
|
|||
|
||||
|
||||
class TestReportSchedulesApi(SupersetTestCase):
|
||||
def insert_report_schedule(
|
||||
self,
|
||||
type: str,
|
||||
name: str,
|
||||
crontab: str,
|
||||
sql: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
chart: Optional[Slice] = None,
|
||||
dashboard: Optional[Dashboard] = None,
|
||||
database: Optional[Database] = None,
|
||||
owners: Optional[List[User]] = None,
|
||||
validator_type: Optional[str] = None,
|
||||
validator_config_json: Optional[str] = None,
|
||||
log_retention: Optional[int] = None,
|
||||
grace_period: Optional[int] = None,
|
||||
recipients: Optional[List[ReportRecipients]] = None,
|
||||
logs: Optional[List[ReportExecutionLog]] = None,
|
||||
) -> ReportSchedule:
|
||||
owners = owners or []
|
||||
recipients = recipients or []
|
||||
logs = logs or []
|
||||
report_schedule = ReportSchedule(
|
||||
type=type,
|
||||
name=name,
|
||||
crontab=crontab,
|
||||
sql=sql,
|
||||
description=description,
|
||||
chart=chart,
|
||||
dashboard=dashboard,
|
||||
database=database,
|
||||
owners=owners,
|
||||
validator_type=validator_type,
|
||||
validator_config_json=validator_config_json,
|
||||
log_retention=log_retention,
|
||||
grace_period=grace_period,
|
||||
recipients=recipients,
|
||||
logs=logs,
|
||||
)
|
||||
db.session.add(report_schedule)
|
||||
db.session.commit()
|
||||
return report_schedule
|
||||
|
||||
@pytest.fixture()
|
||||
def create_report_schedules(self):
|
||||
with self.create_app().app_context():
|
||||
|
|
@ -116,7 +75,7 @@ class TestReportSchedulesApi(SupersetTestCase):
|
|||
)
|
||||
)
|
||||
report_schedules.append(
|
||||
self.insert_report_schedule(
|
||||
insert_report_schedule(
|
||||
type=ReportScheduleType.ALERT,
|
||||
name=f"name{cx}",
|
||||
crontab=f"*/{cx} * * * *",
|
||||
|
|
@ -169,10 +128,6 @@ class TestReportSchedulesApi(SupersetTestCase):
|
|||
"last_value_row_json": report_schedule.last_value_row_json,
|
||||
"log_retention": report_schedule.log_retention,
|
||||
"name": report_schedule.name,
|
||||
"owners": [
|
||||
{"first_name": "admin", "id": 1, "last_name": "user"},
|
||||
{"first_name": "alpha", "id": 5, "last_name": "user"},
|
||||
],
|
||||
"recipients": [
|
||||
{
|
||||
"id": report_schedule.recipients[0].id,
|
||||
|
|
@ -184,7 +139,16 @@ class TestReportSchedulesApi(SupersetTestCase):
|
|||
"validator_config_json": report_schedule.validator_config_json,
|
||||
"validator_type": report_schedule.validator_type,
|
||||
}
|
||||
assert data["result"] == expected_result
|
||||
for key in expected_result:
|
||||
assert data["result"][key] == expected_result[key]
|
||||
# needed because order may vary
|
||||
assert {"first_name": "admin", "id": 1, "last_name": "user"} in data["result"][
|
||||
"owners"
|
||||
]
|
||||
assert {"first_name": "alpha", "id": 5, "last_name": "user"} in data["result"][
|
||||
"owners"
|
||||
]
|
||||
assert len(data["result"]["owners"]) == 2
|
||||
|
||||
def test_info_report_schedule(self):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -0,0 +1,531 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from contextlib2 import contextmanager
|
||||
from freezegun import freeze_time
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
from superset import db
|
||||
from superset.models.core import Database
|
||||
from superset.models.dashboard import Dashboard
|
||||
from superset.models.reports import (
|
||||
ReportExecutionLog,
|
||||
ReportLogState,
|
||||
ReportRecipients,
|
||||
ReportRecipientType,
|
||||
ReportSchedule,
|
||||
ReportScheduleType,
|
||||
ReportScheduleValidatorType,
|
||||
)
|
||||
from superset.models.slice import Slice
|
||||
from superset.reports.commands.exceptions import (
|
||||
AlertQueryMultipleColumnsError,
|
||||
AlertQueryMultipleRowsError,
|
||||
ReportScheduleNotFoundError,
|
||||
ReportScheduleNotificationError,
|
||||
ReportSchedulePreviousWorkingError,
|
||||
)
|
||||
from superset.reports.commands.execute import AsyncExecuteReportScheduleCommand
|
||||
from superset.utils.core import get_example_database
|
||||
from tests.reports.utils import insert_report_schedule
|
||||
from tests.test_app import app
|
||||
from tests.utils import read_fixture
|
||||
|
||||
|
||||
def get_target_from_report_schedule(report_schedule) -> List[str]:
|
||||
return [
|
||||
json.loads(recipient.recipient_config_json)["target"]
|
||||
for recipient in report_schedule.recipients
|
||||
]
|
||||
|
||||
|
||||
def assert_log(state: str, error_message: Optional[str] = None):
|
||||
db.session.commit()
|
||||
logs = db.session.query(ReportExecutionLog).all()
|
||||
assert len(logs) == 1
|
||||
assert logs[0].error_message == error_message
|
||||
assert logs[0].state == state
|
||||
|
||||
|
||||
def create_report_notification(
|
||||
email_target: Optional[str] = None,
|
||||
slack_channel: Optional[str] = None,
|
||||
chart: Optional[Slice] = None,
|
||||
dashboard: Optional[Dashboard] = None,
|
||||
database: Optional[Database] = None,
|
||||
sql: Optional[str] = None,
|
||||
report_type: Optional[str] = None,
|
||||
validator_type: Optional[str] = None,
|
||||
validator_config_json: Optional[str] = None,
|
||||
) -> ReportSchedule:
|
||||
report_type = report_type or ReportScheduleType.REPORT
|
||||
target = email_target or slack_channel
|
||||
config_json = {"target": target}
|
||||
if slack_channel:
|
||||
recipient = ReportRecipients(
|
||||
type=ReportRecipientType.SLACK,
|
||||
recipient_config_json=json.dumps(config_json),
|
||||
)
|
||||
else:
|
||||
recipient = ReportRecipients(
|
||||
type=ReportRecipientType.EMAIL,
|
||||
recipient_config_json=json.dumps(config_json),
|
||||
)
|
||||
|
||||
report_schedule = insert_report_schedule(
|
||||
type=report_type,
|
||||
name=f"report",
|
||||
crontab=f"0 9 * * *",
|
||||
description=f"Daily report",
|
||||
sql=sql,
|
||||
chart=chart,
|
||||
dashboard=dashboard,
|
||||
database=database,
|
||||
recipients=[recipient],
|
||||
validator_type=validator_type,
|
||||
validator_config_json=validator_config_json,
|
||||
)
|
||||
return report_schedule
|
||||
|
||||
|
||||
@pytest.yield_fixture()
|
||||
def create_report_email_chart():
|
||||
with app.app_context():
|
||||
chart = db.session.query(Slice).first()
|
||||
report_schedule = create_report_notification(
|
||||
email_target="target@email.com", chart=chart
|
||||
)
|
||||
yield report_schedule
|
||||
|
||||
db.session.delete(report_schedule)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
@pytest.yield_fixture()
|
||||
def create_report_email_dashboard():
|
||||
with app.app_context():
|
||||
dashboard = db.session.query(Dashboard).first()
|
||||
report_schedule = create_report_notification(
|
||||
email_target="target@email.com", dashboard=dashboard
|
||||
)
|
||||
yield report_schedule
|
||||
|
||||
db.session.delete(report_schedule)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
@pytest.yield_fixture()
|
||||
def create_report_slack_chart():
|
||||
with app.app_context():
|
||||
chart = db.session.query(Slice).first()
|
||||
report_schedule = create_report_notification(
|
||||
slack_channel="slack_channel", chart=chart
|
||||
)
|
||||
yield report_schedule
|
||||
|
||||
db.session.delete(report_schedule)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
@pytest.yield_fixture()
|
||||
def create_report_slack_chart_working():
|
||||
with app.app_context():
|
||||
chart = db.session.query(Slice).first()
|
||||
report_schedule = create_report_notification(
|
||||
slack_channel="slack_channel", chart=chart
|
||||
)
|
||||
report_schedule.last_state = ReportLogState.WORKING
|
||||
db.session.commit()
|
||||
yield report_schedule
|
||||
|
||||
db.session.delete(report_schedule)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
@pytest.yield_fixture(
|
||||
params=["alert1", "alert2", "alert3", "alert4", "alert5", "alert6", "alert7"]
|
||||
)
|
||||
def create_alert_email_chart(request):
|
||||
param_config = {
|
||||
"alert1": {
|
||||
"sql": "SELECT 10 as metric",
|
||||
"validator_type": ReportScheduleValidatorType.OPERATOR,
|
||||
"validator_config_json": '{"op": ">", "threshold": 9}',
|
||||
},
|
||||
"alert2": {
|
||||
"sql": "SELECT 10 as metric",
|
||||
"validator_type": ReportScheduleValidatorType.OPERATOR,
|
||||
"validator_config_json": '{"op": ">=", "threshold": 10}',
|
||||
},
|
||||
"alert3": {
|
||||
"sql": "SELECT 10 as metric",
|
||||
"validator_type": ReportScheduleValidatorType.OPERATOR,
|
||||
"validator_config_json": '{"op": "<", "threshold": 11}',
|
||||
},
|
||||
"alert4": {
|
||||
"sql": "SELECT 10 as metric",
|
||||
"validator_type": ReportScheduleValidatorType.OPERATOR,
|
||||
"validator_config_json": '{"op": "<=", "threshold": 10}',
|
||||
},
|
||||
"alert5": {
|
||||
"sql": "SELECT 10 as metric",
|
||||
"validator_type": ReportScheduleValidatorType.OPERATOR,
|
||||
"validator_config_json": '{"op": "!=", "threshold": 11}',
|
||||
},
|
||||
"alert6": {
|
||||
"sql": "SELECT 'something' as metric",
|
||||
"validator_type": ReportScheduleValidatorType.NOT_NULL,
|
||||
"validator_config_json": "{}",
|
||||
},
|
||||
"alert7": {
|
||||
"sql": "SELECT {{ 5 + 5 }} as metric",
|
||||
"validator_type": ReportScheduleValidatorType.OPERATOR,
|
||||
"validator_config_json": '{"op": "!=", "threshold": 11}',
|
||||
},
|
||||
}
|
||||
with app.app_context():
|
||||
chart = db.session.query(Slice).first()
|
||||
example_database = get_example_database()
|
||||
|
||||
report_schedule = create_report_notification(
|
||||
email_target="target@email.com",
|
||||
chart=chart,
|
||||
report_type=ReportScheduleType.ALERT,
|
||||
database=example_database,
|
||||
sql=param_config[request.param]["sql"],
|
||||
validator_type=param_config[request.param]["validator_type"],
|
||||
validator_config_json=param_config[request.param]["validator_config_json"],
|
||||
)
|
||||
yield report_schedule
|
||||
|
||||
db.session.delete(report_schedule)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def create_test_table_context(database: Database):
|
||||
database.get_sqla_engine().execute(
|
||||
"CREATE TABLE test_table AS SELECT 1 as first, 2 as second"
|
||||
)
|
||||
database.get_sqla_engine().execute(
|
||||
"INSERT INTO test_table (first, second) VALUES (1, 2)"
|
||||
)
|
||||
database.get_sqla_engine().execute(
|
||||
"INSERT INTO test_table (first, second) VALUES (3, 4)"
|
||||
)
|
||||
|
||||
yield db.session
|
||||
database.get_sqla_engine().execute("DROP TABLE test_table")
|
||||
|
||||
|
||||
@pytest.yield_fixture(
|
||||
params=["alert1", "alert2", "alert3", "alert4", "alert5", "alert6"]
|
||||
)
|
||||
def create_no_alert_email_chart(request):
|
||||
param_config = {
|
||||
"alert1": {
|
||||
"sql": "SELECT 10 as metric",
|
||||
"validator_type": ReportScheduleValidatorType.OPERATOR,
|
||||
"validator_config_json": '{"op": "<", "threshold": 10}',
|
||||
},
|
||||
"alert2": {
|
||||
"sql": "SELECT 10 as metric",
|
||||
"validator_type": ReportScheduleValidatorType.OPERATOR,
|
||||
"validator_config_json": '{"op": ">=", "threshold": 11}',
|
||||
},
|
||||
"alert3": {
|
||||
"sql": "SELECT 10 as metric",
|
||||
"validator_type": ReportScheduleValidatorType.OPERATOR,
|
||||
"validator_config_json": '{"op": "<", "threshold": 10}',
|
||||
},
|
||||
"alert4": {
|
||||
"sql": "SELECT 10 as metric",
|
||||
"validator_type": ReportScheduleValidatorType.OPERATOR,
|
||||
"validator_config_json": '{"op": "<=", "threshold": 9}',
|
||||
},
|
||||
"alert5": {
|
||||
"sql": "SELECT 10 as metric",
|
||||
"validator_type": ReportScheduleValidatorType.OPERATOR,
|
||||
"validator_config_json": '{"op": "!=", "threshold": 10}',
|
||||
},
|
||||
"alert6": {
|
||||
"sql": "SELECT first from test_table where first=0",
|
||||
"validator_type": ReportScheduleValidatorType.NOT_NULL,
|
||||
"validator_config_json": "{}",
|
||||
},
|
||||
}
|
||||
with app.app_context():
|
||||
chart = db.session.query(Slice).first()
|
||||
example_database = get_example_database()
|
||||
with create_test_table_context(example_database):
|
||||
|
||||
report_schedule = create_report_notification(
|
||||
email_target="target@email.com",
|
||||
chart=chart,
|
||||
report_type=ReportScheduleType.ALERT,
|
||||
database=example_database,
|
||||
sql=param_config[request.param]["sql"],
|
||||
validator_type=param_config[request.param]["validator_type"],
|
||||
validator_config_json=param_config[request.param][
|
||||
"validator_config_json"
|
||||
],
|
||||
)
|
||||
yield report_schedule
|
||||
|
||||
db.session.delete(report_schedule)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
@pytest.yield_fixture(params=["alert1", "alert2"])
|
||||
def create_mul_alert_email_chart(request):
|
||||
param_config = {
|
||||
"alert1": {
|
||||
"sql": "SELECT first from test_table",
|
||||
"validator_type": ReportScheduleValidatorType.OPERATOR,
|
||||
"validator_config_json": '{"op": "<", "threshold": 10}',
|
||||
},
|
||||
"alert2": {
|
||||
"sql": "SELECT first, second from test_table",
|
||||
"validator_type": ReportScheduleValidatorType.OPERATOR,
|
||||
"validator_config_json": '{"op": "<", "threshold": 10}',
|
||||
},
|
||||
}
|
||||
with app.app_context():
|
||||
chart = db.session.query(Slice).first()
|
||||
example_database = get_example_database()
|
||||
with create_test_table_context(example_database):
|
||||
|
||||
report_schedule = create_report_notification(
|
||||
email_target="target@email.com",
|
||||
chart=chart,
|
||||
report_type=ReportScheduleType.ALERT,
|
||||
database=example_database,
|
||||
sql=param_config[request.param]["sql"],
|
||||
validator_type=param_config[request.param]["validator_type"],
|
||||
validator_config_json=param_config[request.param][
|
||||
"validator_config_json"
|
||||
],
|
||||
)
|
||||
yield report_schedule
|
||||
|
||||
# needed for MySQL
|
||||
logs = (
|
||||
db.session.query(ReportExecutionLog)
|
||||
.filter(ReportExecutionLog.report_schedule == report_schedule)
|
||||
.all()
|
||||
)
|
||||
for log in logs:
|
||||
db.session.delete(log)
|
||||
db.session.commit()
|
||||
db.session.delete(report_schedule)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("create_report_email_chart")
|
||||
@patch("superset.reports.notifications.email.send_email_smtp")
|
||||
@patch("superset.utils.screenshots.ChartScreenshot.compute_and_cache")
|
||||
def test_email_chart_report_schedule(
|
||||
screenshot_mock, email_mock, create_report_email_chart
|
||||
):
|
||||
"""
|
||||
ExecuteReport Command: Test chart email report schedule
|
||||
"""
|
||||
# setup screenshot mock
|
||||
screenshot = read_fixture("sample.png")
|
||||
screenshot_mock.return_value = screenshot
|
||||
|
||||
with freeze_time("2020-01-01T00:00:00Z"):
|
||||
AsyncExecuteReportScheduleCommand(
|
||||
create_report_email_chart.id, datetime.utcnow()
|
||||
).run()
|
||||
|
||||
notification_targets = get_target_from_report_schedule(
|
||||
create_report_email_chart
|
||||
)
|
||||
# Assert the email smtp address
|
||||
assert email_mock.call_args[0][0] == notification_targets[0]
|
||||
# Assert the email inline screenshot
|
||||
smtp_images = email_mock.call_args[1]["images"]
|
||||
assert smtp_images[list(smtp_images.keys())[0]] == screenshot
|
||||
# Assert logs are correct
|
||||
assert_log(ReportLogState.SUCCESS)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("create_report_email_dashboard")
|
||||
@patch("superset.reports.notifications.email.send_email_smtp")
|
||||
@patch("superset.utils.screenshots.DashboardScreenshot.compute_and_cache")
|
||||
def test_email_dashboard_report_schedule(
|
||||
screenshot_mock, email_mock, create_report_email_dashboard
|
||||
):
|
||||
"""
|
||||
ExecuteReport Command: Test dashboard email report schedule
|
||||
"""
|
||||
# setup screenshot mock
|
||||
screenshot = read_fixture("sample.png")
|
||||
screenshot_mock.return_value = screenshot
|
||||
|
||||
with freeze_time("2020-01-01T00:00:00Z"):
|
||||
AsyncExecuteReportScheduleCommand(
|
||||
create_report_email_dashboard.id, datetime.utcnow()
|
||||
).run()
|
||||
|
||||
notification_targets = get_target_from_report_schedule(
|
||||
create_report_email_dashboard
|
||||
)
|
||||
# Assert the email smtp address
|
||||
assert email_mock.call_args[0][0] == notification_targets[0]
|
||||
# Assert the email inline screenshot
|
||||
smtp_images = email_mock.call_args[1]["images"]
|
||||
assert smtp_images[list(smtp_images.keys())[0]] == screenshot
|
||||
# Assert logs are correct
|
||||
assert_log(ReportLogState.SUCCESS)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("create_report_slack_chart")
|
||||
@patch("superset.reports.notifications.slack.WebClient.files_upload")
|
||||
@patch("superset.utils.screenshots.ChartScreenshot.compute_and_cache")
|
||||
def test_slack_chart_report_schedule(
|
||||
screenshot_mock, file_upload_mock, create_report_slack_chart
|
||||
):
|
||||
"""
|
||||
ExecuteReport Command: Test chart slack report schedule
|
||||
"""
|
||||
# setup screenshot mock
|
||||
screenshot = read_fixture("sample.png")
|
||||
screenshot_mock.return_value = screenshot
|
||||
|
||||
with freeze_time("2020-01-01T00:00:00Z"):
|
||||
AsyncExecuteReportScheduleCommand(
|
||||
create_report_slack_chart.id, datetime.utcnow()
|
||||
).run()
|
||||
|
||||
notification_targets = get_target_from_report_schedule(
|
||||
create_report_slack_chart
|
||||
)
|
||||
assert file_upload_mock.call_args[1]["channels"] == notification_targets[0]
|
||||
assert file_upload_mock.call_args[1]["file"] == screenshot
|
||||
|
||||
# Assert logs are correct
|
||||
assert_log(ReportLogState.SUCCESS)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("create_report_slack_chart")
|
||||
def test_report_schedule_not_found(create_report_slack_chart):
|
||||
"""
|
||||
ExecuteReport Command: Test report schedule not found
|
||||
"""
|
||||
max_id = db.session.query(func.max(ReportSchedule.id)).scalar()
|
||||
with pytest.raises(ReportScheduleNotFoundError):
|
||||
AsyncExecuteReportScheduleCommand(max_id + 1, datetime.utcnow()).run()
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("create_report_slack_chart_working")
|
||||
def test_report_schedule_working(create_report_slack_chart_working):
|
||||
"""
|
||||
ExecuteReport Command: Test report schedule still working
|
||||
"""
|
||||
# setup screenshot mock
|
||||
with pytest.raises(ReportSchedulePreviousWorkingError):
|
||||
AsyncExecuteReportScheduleCommand(
|
||||
create_report_slack_chart_working.id, datetime.utcnow()
|
||||
).run()
|
||||
|
||||
assert_log(
|
||||
ReportLogState.ERROR, error_message=ReportSchedulePreviousWorkingError.message
|
||||
)
|
||||
assert create_report_slack_chart_working.last_state == ReportLogState.WORKING
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("create_report_email_dashboard")
|
||||
@patch("superset.reports.notifications.email.send_email_smtp")
|
||||
@patch("superset.utils.screenshots.DashboardScreenshot.compute_and_cache")
|
||||
def test_email_dashboard_report_fails(
|
||||
screenshot_mock, email_mock, create_report_email_dashboard
|
||||
):
|
||||
"""
|
||||
ExecuteReport Command: Test dashboard email report schedule notification fails
|
||||
"""
|
||||
# setup screenshot mock
|
||||
from smtplib import SMTPException
|
||||
|
||||
screenshot = read_fixture("sample.png")
|
||||
screenshot_mock.return_value = screenshot
|
||||
email_mock.side_effect = SMTPException("Could not connect to SMTP XPTO")
|
||||
|
||||
with pytest.raises(ReportScheduleNotificationError):
|
||||
AsyncExecuteReportScheduleCommand(
|
||||
create_report_email_dashboard.id, datetime.utcnow()
|
||||
).run()
|
||||
|
||||
assert_log(ReportLogState.ERROR, error_message="Could not connect to SMTP XPTO")
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("create_alert_email_chart")
|
||||
@patch("superset.reports.notifications.email.send_email_smtp")
|
||||
@patch("superset.utils.screenshots.ChartScreenshot.compute_and_cache")
|
||||
def test_slack_chart_alert(screenshot_mock, email_mock, create_alert_email_chart):
|
||||
"""
|
||||
ExecuteReport Command: Test chart slack alert
|
||||
"""
|
||||
# setup screenshot mock
|
||||
screenshot = read_fixture("sample.png")
|
||||
screenshot_mock.return_value = screenshot
|
||||
|
||||
with freeze_time("2020-01-01T00:00:00Z"):
|
||||
AsyncExecuteReportScheduleCommand(
|
||||
create_alert_email_chart.id, datetime.utcnow()
|
||||
).run()
|
||||
|
||||
notification_targets = get_target_from_report_schedule(create_alert_email_chart)
|
||||
# Assert the email smtp address
|
||||
assert email_mock.call_args[0][0] == notification_targets[0]
|
||||
# Assert the email inline screenshot
|
||||
smtp_images = email_mock.call_args[1]["images"]
|
||||
assert smtp_images[list(smtp_images.keys())[0]] == screenshot
|
||||
# Assert logs are correct
|
||||
assert_log(ReportLogState.SUCCESS)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("create_no_alert_email_chart")
|
||||
def test_email_chart_no_alert(create_no_alert_email_chart):
|
||||
"""
|
||||
ExecuteReport Command: Test chart email no alert
|
||||
"""
|
||||
with freeze_time("2020-01-01T00:00:00Z"):
|
||||
AsyncExecuteReportScheduleCommand(
|
||||
create_no_alert_email_chart.id, datetime.utcnow()
|
||||
).run()
|
||||
assert_log(ReportLogState.NOOP)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("create_mul_alert_email_chart")
|
||||
def test_email_mul_alert(create_mul_alert_email_chart):
|
||||
"""
|
||||
ExecuteReport Command: Test chart email multiple rows
|
||||
"""
|
||||
with freeze_time("2020-01-01T00:00:00Z"):
|
||||
with pytest.raises(
|
||||
(AlertQueryMultipleRowsError, AlertQueryMultipleColumnsError)
|
||||
):
|
||||
AsyncExecuteReportScheduleCommand(
|
||||
create_mul_alert_email_chart.id, datetime.utcnow()
|
||||
).run()
|
||||
|
|
@ -0,0 +1,68 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from flask_appbuilder.security.sqla.models import User
|
||||
|
||||
from superset import db
|
||||
from superset.models.core import Database
|
||||
from superset.models.dashboard import Dashboard
|
||||
from superset.models.reports import ReportExecutionLog, ReportRecipients, ReportSchedule
|
||||
from superset.models.slice import Slice
|
||||
|
||||
|
||||
def insert_report_schedule(
|
||||
type: str,
|
||||
name: str,
|
||||
crontab: str,
|
||||
sql: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
chart: Optional[Slice] = None,
|
||||
dashboard: Optional[Dashboard] = None,
|
||||
database: Optional[Database] = None,
|
||||
owners: Optional[List[User]] = None,
|
||||
validator_type: Optional[str] = None,
|
||||
validator_config_json: Optional[str] = None,
|
||||
log_retention: Optional[int] = None,
|
||||
grace_period: Optional[int] = None,
|
||||
recipients: Optional[List[ReportRecipients]] = None,
|
||||
logs: Optional[List[ReportExecutionLog]] = None,
|
||||
) -> ReportSchedule:
|
||||
owners = owners or []
|
||||
recipients = recipients or []
|
||||
logs = logs or []
|
||||
report_schedule = ReportSchedule(
|
||||
type=type,
|
||||
name=name,
|
||||
crontab=crontab,
|
||||
sql=sql,
|
||||
description=description,
|
||||
chart=chart,
|
||||
dashboard=dashboard,
|
||||
database=database,
|
||||
owners=owners,
|
||||
validator_type=validator_type,
|
||||
validator_config_json=validator_config_json,
|
||||
log_retention=log_retention,
|
||||
grace_period=grace_period,
|
||||
recipients=recipients,
|
||||
logs=logs,
|
||||
)
|
||||
db.session.add(report_schedule)
|
||||
db.session.commit()
|
||||
return report_schedule
|
||||
Loading…
Reference in New Issue