fix: use nullpool in the celery workers (#10819)
* Use nullpool in the celery workers * Address comments Co-authored-by: bogdan kyryliuk <bogdankyryliuk@dropbox.com>
This commit is contained in:
parent
dd7f3d5402
commit
ac2937a6c5
|
|
@ -34,6 +34,7 @@ from superset import app, appbuilder, security_manager
|
|||
from superset.app import create_app
|
||||
from superset.extensions import celery_app, db
|
||||
from superset.utils import core as utils
|
||||
from superset.utils.celery import session_scope
|
||||
from superset.utils.urls import get_url_path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -619,6 +620,11 @@ def alert() -> None:
|
|||
from superset.tasks.schedules import schedule_window
|
||||
|
||||
click.secho("Processing one alert loop", fg="green")
|
||||
with session_scope(nullpool=True) as session:
|
||||
schedule_window(
|
||||
ScheduleType.alert, datetime.now() - timedelta(1000), datetime.now(), 6000
|
||||
report_type=ScheduleType.alert,
|
||||
start_at=datetime.now() - timedelta(1000),
|
||||
stop_at=datetime.now(),
|
||||
resolution=6000,
|
||||
session=session,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -19,33 +19,25 @@ import uuid
|
|||
from contextlib import closing
|
||||
from datetime import datetime
|
||||
from sys import getsizeof
|
||||
from typing import Any, cast, Dict, Iterator, List, Optional, Tuple, Union
|
||||
from typing import Any, cast, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import backoff
|
||||
import msgpack
|
||||
import pyarrow as pa
|
||||
import simplejson as json
|
||||
import sqlalchemy
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from celery.task.base import Task
|
||||
from contextlib2 import contextmanager
|
||||
from flask_babel import lazy_gettext as _
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy.pool import NullPool
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from superset import (
|
||||
app,
|
||||
db,
|
||||
results_backend,
|
||||
results_backend_use_msgpack,
|
||||
security_manager,
|
||||
)
|
||||
from superset import app, results_backend, results_backend_use_msgpack, security_manager
|
||||
from superset.dataframe import df_to_records
|
||||
from superset.db_engine_specs import BaseEngineSpec
|
||||
from superset.extensions import celery_app
|
||||
from superset.models.sql_lab import Query
|
||||
from superset.result_set import SupersetResultSet
|
||||
from superset.sql_parse import ParsedQuery
|
||||
from superset.utils.celery import session_scope
|
||||
from superset.utils.core import (
|
||||
json_iso_dttm_ser,
|
||||
QuerySource,
|
||||
|
|
@ -121,35 +113,6 @@ def get_query(query_id: int, session: Session) -> Query:
|
|||
raise SqlLabException("Failed at getting query")
|
||||
|
||||
|
||||
@contextmanager
|
||||
def session_scope(nullpool: bool) -> Iterator[Session]:
|
||||
"""Provide a transactional scope around a series of operations."""
|
||||
database_uri = app.config["SQLALCHEMY_DATABASE_URI"]
|
||||
if "sqlite" in database_uri:
|
||||
logger.warning(
|
||||
"SQLite Database support for metadata databases will be removed \
|
||||
in a future version of Superset."
|
||||
)
|
||||
if nullpool:
|
||||
engine = sqlalchemy.create_engine(database_uri, poolclass=NullPool)
|
||||
session_class = sessionmaker()
|
||||
session_class.configure(bind=engine)
|
||||
session = session_class()
|
||||
else:
|
||||
session = db.session()
|
||||
session.commit() # HACK
|
||||
|
||||
try:
|
||||
yield session
|
||||
session.commit()
|
||||
except Exception as ex:
|
||||
session.rollback()
|
||||
logger.exception(ex)
|
||||
raise
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
name="sql_lab.get_sql_results",
|
||||
bind=True,
|
||||
|
|
|
|||
|
|
@ -20,22 +20,24 @@ from datetime import datetime
|
|||
from typing import Optional
|
||||
|
||||
import pandas as pd
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from superset import db
|
||||
from superset.models.alerts import Alert, SQLObservation
|
||||
from superset.sql_parse import ParsedQuery
|
||||
|
||||
logger = logging.getLogger("tasks.email_reports")
|
||||
|
||||
|
||||
def observe(alert_id: int) -> Optional[str]:
|
||||
# Session needs to be passed along in the celery workers and db.session cannot be used.
|
||||
# For more info see: https://github.com/apache/incubator-superset/issues/10530
|
||||
def observe(alert_id: int, session: Session) -> Optional[str]:
|
||||
"""
|
||||
Runs the SQL query in an alert's SQLObserver and then
|
||||
stores the result in a SQLObservation.
|
||||
Returns an error message if the observer value was not valid
|
||||
"""
|
||||
|
||||
alert = db.session.query(Alert).filter_by(id=alert_id).one()
|
||||
alert = session.query(Alert).filter_by(id=alert_id).one()
|
||||
sql_observer = alert.sql_observer[0]
|
||||
|
||||
value = None
|
||||
|
|
@ -57,8 +59,8 @@ def observe(alert_id: int) -> Optional[str]:
|
|||
error_msg=error_msg,
|
||||
)
|
||||
|
||||
db.session.add(observation)
|
||||
db.session.commit()
|
||||
session.add(observation)
|
||||
session.commit()
|
||||
|
||||
return error_msg
|
||||
|
||||
|
|
|
|||
|
|
@ -47,8 +47,9 @@ from selenium.common.exceptions import WebDriverException
|
|||
from selenium.webdriver import chrome, firefox
|
||||
from selenium.webdriver.remote.webdriver import WebDriver
|
||||
from sqlalchemy.exc import NoSuchColumnError, ResourceClosedError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from superset import app, db, security_manager, thumbnail_cache
|
||||
from superset import app, security_manager, thumbnail_cache
|
||||
from superset.extensions import celery_app, machine_auth_provider_factory
|
||||
from superset.models.alerts import Alert, AlertLog
|
||||
from superset.models.dashboard import Dashboard
|
||||
|
|
@ -62,6 +63,7 @@ from superset.models.slice import Slice
|
|||
from superset.tasks.alerts.observer import observe
|
||||
from superset.tasks.alerts.validator import get_validator_function
|
||||
from superset.tasks.slack_util import deliver_slack_msg
|
||||
from superset.utils.celery import session_scope
|
||||
from superset.utils.core import get_email_address_list, send_email_smtp
|
||||
from superset.utils.screenshots import ChartScreenshot, WebDriverProxy
|
||||
from superset.utils.urls import get_url_path
|
||||
|
|
@ -225,7 +227,7 @@ def destroy_webdriver(
|
|||
pass
|
||||
|
||||
|
||||
def deliver_dashboard(
|
||||
def deliver_dashboard( # pylint: disable=too-many-locals
|
||||
dashboard_id: int,
|
||||
recipients: Optional[str],
|
||||
slack_channel: Optional[str],
|
||||
|
|
@ -236,7 +238,8 @@ def deliver_dashboard(
|
|||
"""
|
||||
Given a schedule, delivery the dashboard as an email report
|
||||
"""
|
||||
dashboard = db.session.query(Dashboard).filter_by(id=dashboard_id).one()
|
||||
with session_scope(nullpool=True) as session:
|
||||
dashboard = session.query(Dashboard).filter_by(id=dashboard_id).one()
|
||||
|
||||
dashboard_url = _get_url_path(
|
||||
"Superset.dashboard", dashboard_id_or_slug=dashboard.id
|
||||
|
|
@ -362,8 +365,8 @@ def _get_slice_data(slc: Slice, delivery_type: EmailDeliveryType) -> ReportConte
|
|||
return ReportContent(body, data, None, slack_message, content)
|
||||
|
||||
|
||||
def _get_slice_screenshot(slice_id: int) -> ScreenshotData:
|
||||
slice_obj = db.session.query(Slice).get(slice_id)
|
||||
def _get_slice_screenshot(slice_id: int, session: Session) -> ScreenshotData:
|
||||
slice_obj = session.query(Slice).get(slice_id)
|
||||
|
||||
chart_url = get_url_path("Superset.slice", slice_id=slice_obj.id, standalone="true")
|
||||
screenshot = ChartScreenshot(chart_url, slice_obj.digest)
|
||||
|
|
@ -376,7 +379,7 @@ def _get_slice_screenshot(slice_id: int) -> ScreenshotData:
|
|||
user=user, cache=thumbnail_cache, force=True,
|
||||
)
|
||||
|
||||
db.session.commit()
|
||||
session.commit()
|
||||
return ScreenshotData(image_url, image_data)
|
||||
|
||||
|
||||
|
|
@ -427,11 +430,12 @@ def deliver_slice( # pylint: disable=too-many-arguments
|
|||
delivery_type: EmailDeliveryType,
|
||||
email_format: SliceEmailReportFormat,
|
||||
deliver_as_group: bool,
|
||||
session: Session,
|
||||
) -> None:
|
||||
"""
|
||||
Given a schedule, delivery the slice as an email report
|
||||
"""
|
||||
slc = db.session.query(Slice).filter_by(id=slice_id).one()
|
||||
slc = session.query(Slice).filter_by(id=slice_id).one()
|
||||
|
||||
if email_format == SliceEmailReportFormat.data:
|
||||
report_content = _get_slice_data(slc, delivery_type)
|
||||
|
|
@ -477,7 +481,8 @@ def schedule_email_report( # pylint: disable=unused-argument
|
|||
slack_channel: Optional[str] = None,
|
||||
) -> None:
|
||||
model_cls = get_scheduler_model(report_type)
|
||||
schedule = db.create_scoped_session().query(model_cls).get(schedule_id)
|
||||
with session_scope(nullpool=True) as session:
|
||||
schedule = session.query(model_cls).get(schedule_id)
|
||||
|
||||
# The user may have disabled the schedule. If so, ignore this
|
||||
if not schedule or not schedule.active:
|
||||
|
|
@ -487,7 +492,9 @@ def schedule_email_report( # pylint: disable=unused-argument
|
|||
recipients = recipients or schedule.recipients
|
||||
slack_channel = slack_channel or schedule.slack_channel
|
||||
logger.info(
|
||||
"Starting report for slack: %s and recipients: %s.", slack_channel, recipients
|
||||
"Starting report for slack: %s and recipients: %s.",
|
||||
slack_channel,
|
||||
recipients,
|
||||
)
|
||||
|
||||
if report_type == ScheduleType.dashboard:
|
||||
|
|
@ -506,6 +513,7 @@ def schedule_email_report( # pylint: disable=unused-argument
|
|||
schedule.delivery_type,
|
||||
schedule.email_format,
|
||||
schedule.deliver_as_group,
|
||||
session,
|
||||
)
|
||||
else:
|
||||
raise RuntimeError("Unknown report type")
|
||||
|
|
@ -529,9 +537,8 @@ def schedule_alert_query( # pylint: disable=unused-argument
|
|||
slack_channel: Optional[str] = None,
|
||||
) -> None:
|
||||
model_cls = get_scheduler_model(report_type)
|
||||
|
||||
try:
|
||||
schedule = db.session.query(model_cls).get(schedule_id)
|
||||
with session_scope(nullpool=True) as session:
|
||||
schedule = session.query(model_cls).get(schedule_id)
|
||||
|
||||
# The user may have disabled the schedule. If so, ignore this
|
||||
if not schedule or not schedule.active:
|
||||
|
|
@ -539,15 +546,11 @@ def schedule_alert_query( # pylint: disable=unused-argument
|
|||
return
|
||||
|
||||
if report_type == ScheduleType.alert:
|
||||
evaluate_alert(schedule.id, schedule.label, recipients, slack_channel)
|
||||
evaluate_alert(
|
||||
schedule.id, schedule.label, session, recipients, slack_channel
|
||||
)
|
||||
else:
|
||||
raise RuntimeError("Unknown report type")
|
||||
except NoSuchColumnError as column_error:
|
||||
stats_logger.incr("run_alert_task.error.nosuchcolumnerror")
|
||||
raise column_error
|
||||
except ResourceClosedError as resource_error:
|
||||
stats_logger.incr("run_alert_task.error.resourceclosederror")
|
||||
raise resource_error
|
||||
|
||||
|
||||
class AlertState:
|
||||
|
|
@ -558,6 +561,7 @@ class AlertState:
|
|||
|
||||
def deliver_alert(
|
||||
alert_id: int,
|
||||
session: Session,
|
||||
recipients: Optional[str] = None,
|
||||
slack_channel: Optional[str] = None,
|
||||
) -> None:
|
||||
|
|
@ -566,7 +570,7 @@ def deliver_alert(
|
|||
to its respective email and slack recipients
|
||||
"""
|
||||
|
||||
alert = db.session.query(Alert).get(alert_id)
|
||||
alert = session.query(Alert).get(alert_id)
|
||||
|
||||
logging.info("Triggering alert: %s", alert)
|
||||
|
||||
|
|
@ -588,7 +592,7 @@ def deliver_alert(
|
|||
str(alert.observations[-1].value),
|
||||
validation_error_message,
|
||||
_get_url_path("AlertModelView.show", user_friendly=True, pk=alert_id),
|
||||
_get_slice_screenshot(alert.slice.id),
|
||||
_get_slice_screenshot(alert.slice.id, session),
|
||||
)
|
||||
else:
|
||||
# TODO: dashboard delivery!
|
||||
|
|
@ -668,6 +672,7 @@ def deliver_slack_alert(alert_content: AlertContent, slack_channel: str) -> None
|
|||
def evaluate_alert(
|
||||
alert_id: int,
|
||||
label: str,
|
||||
session: Session,
|
||||
recipients: Optional[str] = None,
|
||||
slack_channel: Optional[str] = None,
|
||||
) -> None:
|
||||
|
|
@ -680,7 +685,7 @@ def evaluate_alert(
|
|||
|
||||
try:
|
||||
logger.info("Querying observers for alert <%s:%s>", alert_id, label)
|
||||
error_msg = observe(alert_id)
|
||||
error_msg = observe(alert_id, session)
|
||||
if error_msg:
|
||||
state = AlertState.ERROR
|
||||
logging.error(error_msg)
|
||||
|
|
@ -694,17 +699,17 @@ def evaluate_alert(
|
|||
if state != AlertState.ERROR:
|
||||
# Don't validate alert on test runs since it may not be triggered
|
||||
if recipients or slack_channel:
|
||||
deliver_alert(alert_id, recipients, slack_channel)
|
||||
deliver_alert(alert_id, session, recipients, slack_channel)
|
||||
state = AlertState.TRIGGER
|
||||
# Validate during regular workflow and deliver only if triggered
|
||||
elif validate_observations(alert_id, label):
|
||||
deliver_alert(alert_id, recipients, slack_channel)
|
||||
elif validate_observations(alert_id, label, session):
|
||||
deliver_alert(alert_id, session, recipients, slack_channel)
|
||||
state = AlertState.TRIGGER
|
||||
else:
|
||||
state = AlertState.PASS
|
||||
|
||||
db.session.commit()
|
||||
alert = db.session.query(Alert).get(alert_id)
|
||||
session.commit()
|
||||
alert = session.query(Alert).get(alert_id)
|
||||
if state != AlertState.ERROR:
|
||||
alert.last_eval_dttm = dttm_end
|
||||
alert.last_state = state
|
||||
|
|
@ -716,10 +721,10 @@ def evaluate_alert(
|
|||
state=state,
|
||||
)
|
||||
)
|
||||
db.session.commit()
|
||||
session.commit()
|
||||
|
||||
|
||||
def validate_observations(alert_id: int, label: str) -> bool:
|
||||
def validate_observations(alert_id: int, label: str, session: Session) -> bool:
|
||||
"""
|
||||
Runs an alert's validators to check if it should be triggered or not
|
||||
If so, return the name of the validator that returned true
|
||||
|
|
@ -727,7 +732,7 @@ def validate_observations(alert_id: int, label: str) -> bool:
|
|||
|
||||
logger.info("Validating observations for alert <%s:%s>", alert_id, label)
|
||||
|
||||
alert = db.session.query(Alert).get(alert_id)
|
||||
alert = session.query(Alert).get(alert_id)
|
||||
if alert.validators:
|
||||
validator = alert.validators[0]
|
||||
validate = get_validator_function(validator.validator_type)
|
||||
|
|
@ -760,7 +765,11 @@ def next_schedules(
|
|||
|
||||
|
||||
def schedule_window(
|
||||
report_type: str, start_at: datetime, stop_at: datetime, resolution: int
|
||||
report_type: str,
|
||||
start_at: datetime,
|
||||
stop_at: datetime,
|
||||
resolution: int,
|
||||
session: Session,
|
||||
) -> None:
|
||||
"""
|
||||
Find all active schedules and schedule celery tasks for
|
||||
|
|
@ -772,8 +781,7 @@ def schedule_window(
|
|||
if not model_cls:
|
||||
return None
|
||||
|
||||
dbsession = db.create_scoped_session()
|
||||
schedules = dbsession.query(model_cls).filter(model_cls.active.is_(True))
|
||||
schedules = session.query(model_cls).filter(model_cls.active.is_(True))
|
||||
|
||||
for schedule in schedules:
|
||||
logging.info("Processing schedule %s", schedule)
|
||||
|
|
@ -810,7 +818,6 @@ def get_scheduler_action(report_type: str) -> Optional[Callable[..., Any]]:
|
|||
@celery_app.task(name="email_reports.schedule_hourly")
|
||||
def schedule_hourly() -> None:
|
||||
""" Celery beat job meant to be invoked hourly """
|
||||
|
||||
if not config["ENABLE_SCHEDULED_EMAIL_REPORTS"]:
|
||||
logger.info("Scheduled email reports not enabled in config")
|
||||
return
|
||||
|
|
@ -820,8 +827,10 @@ def schedule_hourly() -> None:
|
|||
# Get the top of the hour
|
||||
start_at = datetime.now(tzlocal()).replace(microsecond=0, second=0, minute=0)
|
||||
stop_at = start_at + timedelta(seconds=3600)
|
||||
schedule_window(ScheduleType.dashboard, start_at, stop_at, resolution)
|
||||
schedule_window(ScheduleType.slice, start_at, stop_at, resolution)
|
||||
|
||||
with session_scope(nullpool=True) as session:
|
||||
schedule_window(ScheduleType.dashboard, start_at, stop_at, resolution, session)
|
||||
schedule_window(ScheduleType.slice, start_at, stop_at, resolution, session)
|
||||
|
||||
|
||||
@celery_app.task(name="alerts.schedule_check")
|
||||
|
|
@ -833,5 +842,5 @@ def schedule_alerts() -> None:
|
|||
seconds=3600
|
||||
) # process any missed tasks in the past hour
|
||||
stop_at = now + timedelta(seconds=1)
|
||||
|
||||
schedule_window(ScheduleType.alert, start_at, stop_at, resolution)
|
||||
with session_scope(nullpool=True) as session:
|
||||
schedule_window(ScheduleType.alert, start_at, stop_at, resolution, session)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,57 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import logging
|
||||
from typing import Iterator
|
||||
|
||||
import sqlalchemy as sa
|
||||
from contextlib2 import contextmanager
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy.pool import NullPool
|
||||
|
||||
from superset import app, db
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Null pool is used for the celery workers due process forking side effects.
|
||||
# For more info see: https://github.com/apache/incubator-superset/issues/10530
|
||||
@contextmanager
|
||||
def session_scope(nullpool: bool) -> Iterator[Session]:
|
||||
"""Provide a transactional scope around a series of operations."""
|
||||
database_uri = app.config["SQLALCHEMY_DATABASE_URI"]
|
||||
if "sqlite" in database_uri:
|
||||
logger.warning(
|
||||
"SQLite Database support for metadata databases will be removed \
|
||||
in a future version of Superset."
|
||||
)
|
||||
if nullpool:
|
||||
engine = sa.create_engine(database_uri, poolclass=NullPool)
|
||||
session_class = sessionmaker()
|
||||
session_class.configure(bind=engine)
|
||||
session = session_class()
|
||||
else:
|
||||
session = db.session()
|
||||
session.commit() # HACK
|
||||
|
||||
try:
|
||||
yield session
|
||||
session.commit()
|
||||
except Exception as ex:
|
||||
session.rollback()
|
||||
logger.exception(ex)
|
||||
raise
|
||||
finally:
|
||||
session.close()
|
||||
|
|
@ -112,37 +112,37 @@ def test_alert_observer(setup_database):
|
|||
|
||||
# Test SQLObserver with int SQL return
|
||||
alert1 = create_alert(dbsession, "SELECT 55")
|
||||
observe(alert1.id)
|
||||
observe(alert1.id, dbsession)
|
||||
assert alert1.sql_observer[0].observations[-1].value == 55.0
|
||||
assert alert1.sql_observer[0].observations[-1].error_msg is None
|
||||
|
||||
# Test SQLObserver with double SQL return
|
||||
alert2 = create_alert(dbsession, "SELECT 30.0 as wage")
|
||||
observe(alert2.id)
|
||||
observe(alert2.id, dbsession)
|
||||
assert alert2.sql_observer[0].observations[-1].value == 30.0
|
||||
assert alert2.sql_observer[0].observations[-1].error_msg is None
|
||||
|
||||
# Test SQLObserver with NULL result
|
||||
alert3 = create_alert(dbsession, "SELECT null as null_result")
|
||||
observe(alert3.id)
|
||||
observe(alert3.id, dbsession)
|
||||
assert alert3.sql_observer[0].observations[-1].value is None
|
||||
assert alert3.sql_observer[0].observations[-1].error_msg is None
|
||||
|
||||
# Test SQLObserver with empty SQL return
|
||||
alert4 = create_alert(dbsession, "SELECT first FROM test_table WHERE first = -1")
|
||||
observe(alert4.id)
|
||||
observe(alert4.id, dbsession)
|
||||
assert alert4.sql_observer[0].observations[-1].value is None
|
||||
assert alert4.sql_observer[0].observations[-1].error_msg is not None
|
||||
|
||||
# Test SQLObserver with str result
|
||||
alert5 = create_alert(dbsession, "SELECT 'test_string' as string_value")
|
||||
observe(alert5.id)
|
||||
observe(alert5.id, dbsession)
|
||||
assert alert5.sql_observer[0].observations[-1].value is None
|
||||
assert alert5.sql_observer[0].observations[-1].error_msg is not None
|
||||
|
||||
# Test SQLObserver with two row result
|
||||
alert6 = create_alert(dbsession, "SELECT first FROM test_table")
|
||||
observe(alert6.id)
|
||||
observe(alert6.id, dbsession)
|
||||
assert alert6.sql_observer[0].observations[-1].value is None
|
||||
assert alert6.sql_observer[0].observations[-1].error_msg is not None
|
||||
|
||||
|
|
@ -150,7 +150,7 @@ def test_alert_observer(setup_database):
|
|||
alert7 = create_alert(
|
||||
dbsession, "SELECT first, second FROM test_table WHERE first = 1"
|
||||
)
|
||||
observe(alert7.id)
|
||||
observe(alert7.id, dbsession)
|
||||
assert alert7.sql_observer[0].observations[-1].value is None
|
||||
assert alert7.sql_observer[0].observations[-1].error_msg is not None
|
||||
|
||||
|
|
@ -161,22 +161,22 @@ def test_evaluate_alert(mock_deliver_alert, setup_database):
|
|||
|
||||
# Test error with Observer SQL statement
|
||||
alert1 = create_alert(dbsession, "$%^&")
|
||||
evaluate_alert(alert1.id, alert1.label)
|
||||
evaluate_alert(alert1.id, alert1.label, dbsession)
|
||||
assert alert1.logs[-1].state == AlertState.ERROR
|
||||
|
||||
# Test error with alert lacking observer
|
||||
alert2 = dbsession.query(Alert).filter_by(label="No Observer").one()
|
||||
evaluate_alert(alert2.id, alert2.label)
|
||||
evaluate_alert(alert2.id, alert2.label, dbsession)
|
||||
assert alert2.logs[-1].state == AlertState.ERROR
|
||||
|
||||
# Test pass on alert lacking validator
|
||||
alert3 = create_alert(dbsession, "SELECT 55")
|
||||
evaluate_alert(alert3.id, alert3.label)
|
||||
evaluate_alert(alert3.id, alert3.label, dbsession)
|
||||
assert alert3.logs[-1].state == AlertState.PASS
|
||||
|
||||
# Test triggering successful alert
|
||||
alert4 = create_alert(dbsession, "SELECT 55", "not null", "{}")
|
||||
evaluate_alert(alert4.id, alert4.label)
|
||||
evaluate_alert(alert4.id, alert4.label, dbsession)
|
||||
assert mock_deliver_alert.call_count == 1
|
||||
assert alert4.logs[-1].state == AlertState.TRIGGER
|
||||
|
||||
|
|
@ -214,17 +214,17 @@ def test_not_null_validator(setup_database):
|
|||
|
||||
# Test passing SQLObserver with 'null' SQL result
|
||||
alert1 = create_alert(dbsession, "SELECT 0")
|
||||
observe(alert1.id)
|
||||
observe(alert1.id, dbsession)
|
||||
assert not_null_validator(alert1.sql_observer[0], "{}") is False
|
||||
|
||||
# Test passing SQLObserver with empty SQL result
|
||||
alert2 = create_alert(dbsession, "SELECT first FROM test_table WHERE first = -1")
|
||||
observe(alert2.id)
|
||||
observe(alert2.id, dbsession)
|
||||
assert not_null_validator(alert2.sql_observer[0], "{}") is False
|
||||
|
||||
# Test triggering alert with non-null SQL result
|
||||
alert3 = create_alert(dbsession, "SELECT 55")
|
||||
observe(alert3.id)
|
||||
observe(alert3.id, dbsession)
|
||||
assert not_null_validator(alert3.sql_observer[0], "{}") is True
|
||||
|
||||
|
||||
|
|
@ -233,7 +233,7 @@ def test_operator_validator(setup_database):
|
|||
|
||||
# Test passing SQLObserver with empty SQL result
|
||||
alert1 = create_alert(dbsession, "SELECT first FROM test_table WHERE first = -1")
|
||||
observe(alert1.id)
|
||||
observe(alert1.id, dbsession)
|
||||
assert (
|
||||
operator_validator(alert1.sql_observer[0], '{"op": ">=", "threshold": 60}')
|
||||
is False
|
||||
|
|
@ -241,7 +241,7 @@ def test_operator_validator(setup_database):
|
|||
|
||||
# Test passing SQLObserver with result that doesn't pass a greater than threshold
|
||||
alert2 = create_alert(dbsession, "SELECT 55")
|
||||
observe(alert2.id)
|
||||
observe(alert2.id, dbsession)
|
||||
assert (
|
||||
operator_validator(alert2.sql_observer[0], '{"op": ">=", "threshold": 60}')
|
||||
is False
|
||||
|
|
@ -283,23 +283,23 @@ def test_validate_observations(setup_database):
|
|||
|
||||
# Test False on alert with no validator
|
||||
alert1 = create_alert(dbsession, "SELECT 55")
|
||||
assert validate_observations(alert1.id, alert1.label) is False
|
||||
assert validate_observations(alert1.id, alert1.label, dbsession) is False
|
||||
|
||||
# Test False on alert with no observations
|
||||
alert2 = create_alert(dbsession, "SELECT 55", "not null", "{}")
|
||||
assert validate_observations(alert2.id, alert2.label) is False
|
||||
assert validate_observations(alert2.id, alert2.label, dbsession) is False
|
||||
|
||||
# Test False on alert that shouldnt be triggered
|
||||
alert3 = create_alert(dbsession, "SELECT 0", "not null", "{}")
|
||||
observe(alert3.id)
|
||||
assert validate_observations(alert3.id, alert3.label) is False
|
||||
observe(alert3.id, dbsession)
|
||||
assert validate_observations(alert3.id, alert3.label, dbsession) is False
|
||||
|
||||
# Test True on alert that should be triggered
|
||||
alert4 = create_alert(
|
||||
dbsession, "SELECT 55", "operator", '{"op": "<=", "threshold": 60}'
|
||||
)
|
||||
observe(alert4.id)
|
||||
assert validate_observations(alert4.id, alert4.label) is True
|
||||
observe(alert4.id, dbsession)
|
||||
assert validate_observations(alert4.id, alert4.label, dbsession) is True
|
||||
|
||||
|
||||
@patch("superset.tasks.slack_util.WebClient.files_upload")
|
||||
|
|
@ -311,7 +311,7 @@ def test_deliver_alert_screenshot(
|
|||
):
|
||||
dbsession = setup_database
|
||||
alert = create_alert(dbsession, "SELECT 55", "not null", "{}")
|
||||
observe(alert.id)
|
||||
observe(alert.id, dbsession)
|
||||
|
||||
screenshot = read_fixture("sample.png")
|
||||
screenshot_mock.return_value = screenshot
|
||||
|
|
@ -322,7 +322,7 @@ def test_deliver_alert_screenshot(
|
|||
f"http://0.0.0.0:8080/superset/slice/{alert.slice_id}/",
|
||||
]
|
||||
|
||||
deliver_alert(alert_id=alert.id)
|
||||
deliver_alert(alert.id, dbsession)
|
||||
assert email_mock.call_args[1]["images"]["screenshot"] == screenshot
|
||||
assert file_upload_mock.call_args[1] == {
|
||||
"channels": alert.slack_channel,
|
||||
|
|
|
|||
|
|
@ -366,6 +366,7 @@ class TestSchedules(SupersetTestCase):
|
|||
schedule.delivery_type,
|
||||
schedule.email_format,
|
||||
schedule.deliver_as_group,
|
||||
db.session,
|
||||
)
|
||||
mtime.sleep.assert_called_once()
|
||||
driver.screenshot.assert_not_called()
|
||||
|
|
@ -418,6 +419,7 @@ class TestSchedules(SupersetTestCase):
|
|||
schedule.delivery_type,
|
||||
schedule.email_format,
|
||||
schedule.deliver_as_group,
|
||||
db.session,
|
||||
)
|
||||
|
||||
mtime.sleep.assert_called_once()
|
||||
|
|
@ -466,6 +468,7 @@ class TestSchedules(SupersetTestCase):
|
|||
schedule.delivery_type,
|
||||
schedule.email_format,
|
||||
schedule.deliver_as_group,
|
||||
db.session,
|
||||
)
|
||||
|
||||
send_email_smtp.assert_called_once()
|
||||
|
|
@ -510,6 +513,7 @@ class TestSchedules(SupersetTestCase):
|
|||
schedule.delivery_type,
|
||||
schedule.email_format,
|
||||
schedule.deliver_as_group,
|
||||
db.session,
|
||||
)
|
||||
|
||||
send_email_smtp.assert_called_once()
|
||||
|
|
|
|||
Loading…
Reference in New Issue