fix: use nullpool in the celery workers (#10819)

* Use nullpool in the celery workers

* Address comments

Co-authored-by: bogdan kyryliuk <bogdankyryliuk@dropbox.com>
This commit is contained in:
Bogdan 2020-09-10 13:29:57 -07:00 committed by GitHub
parent dd7f3d5402
commit ac2937a6c5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 235 additions and 194 deletions

View File

@ -34,6 +34,7 @@ from superset import app, appbuilder, security_manager
from superset.app import create_app
from superset.extensions import celery_app, db
from superset.utils import core as utils
from superset.utils.celery import session_scope
from superset.utils.urls import get_url_path
logger = logging.getLogger(__name__)
@ -619,6 +620,11 @@ def alert() -> None:
from superset.tasks.schedules import schedule_window
click.secho("Processing one alert loop", fg="green")
schedule_window(
ScheduleType.alert, datetime.now() - timedelta(1000), datetime.now(), 6000
)
with session_scope(nullpool=True) as session:
schedule_window(
report_type=ScheduleType.alert,
start_at=datetime.now() - timedelta(1000),
stop_at=datetime.now(),
resolution=6000,
session=session,
)

View File

@ -19,33 +19,25 @@ import uuid
from contextlib import closing
from datetime import datetime
from sys import getsizeof
from typing import Any, cast, Dict, Iterator, List, Optional, Tuple, Union
from typing import Any, cast, Dict, List, Optional, Tuple, Union
import backoff
import msgpack
import pyarrow as pa
import simplejson as json
import sqlalchemy
from celery.exceptions import SoftTimeLimitExceeded
from celery.task.base import Task
from contextlib2 import contextmanager
from flask_babel import lazy_gettext as _
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import NullPool
from sqlalchemy.orm import Session
from superset import (
app,
db,
results_backend,
results_backend_use_msgpack,
security_manager,
)
from superset import app, results_backend, results_backend_use_msgpack, security_manager
from superset.dataframe import df_to_records
from superset.db_engine_specs import BaseEngineSpec
from superset.extensions import celery_app
from superset.models.sql_lab import Query
from superset.result_set import SupersetResultSet
from superset.sql_parse import ParsedQuery
from superset.utils.celery import session_scope
from superset.utils.core import (
json_iso_dttm_ser,
QuerySource,
@ -121,35 +113,6 @@ def get_query(query_id: int, session: Session) -> Query:
raise SqlLabException("Failed at getting query")
@contextmanager
def session_scope(nullpool: bool) -> Iterator[Session]:
"""Provide a transactional scope around a series of operations."""
database_uri = app.config["SQLALCHEMY_DATABASE_URI"]
if "sqlite" in database_uri:
logger.warning(
"SQLite Database support for metadata databases will be removed \
in a future version of Superset."
)
if nullpool:
engine = sqlalchemy.create_engine(database_uri, poolclass=NullPool)
session_class = sessionmaker()
session_class.configure(bind=engine)
session = session_class()
else:
session = db.session()
session.commit() # HACK
try:
yield session
session.commit()
except Exception as ex:
session.rollback()
logger.exception(ex)
raise
finally:
session.close()
@celery_app.task(
name="sql_lab.get_sql_results",
bind=True,

View File

@ -20,22 +20,24 @@ from datetime import datetime
from typing import Optional
import pandas as pd
from sqlalchemy.orm import Session
from superset import db
from superset.models.alerts import Alert, SQLObservation
from superset.sql_parse import ParsedQuery
logger = logging.getLogger("tasks.email_reports")
def observe(alert_id: int) -> Optional[str]:
# Session needs to be passed along in the celery workers and db.session cannot be used.
# For more info see: https://github.com/apache/incubator-superset/issues/10530
def observe(alert_id: int, session: Session) -> Optional[str]:
"""
Runs the SQL query in an alert's SQLObserver and then
stores the result in a SQLObservation.
Returns an error message if the observer value was not valid
"""
alert = db.session.query(Alert).filter_by(id=alert_id).one()
alert = session.query(Alert).filter_by(id=alert_id).one()
sql_observer = alert.sql_observer[0]
value = None
@ -57,8 +59,8 @@ def observe(alert_id: int) -> Optional[str]:
error_msg=error_msg,
)
db.session.add(observation)
db.session.commit()
session.add(observation)
session.commit()
return error_msg

View File

@ -47,8 +47,9 @@ from selenium.common.exceptions import WebDriverException
from selenium.webdriver import chrome, firefox
from selenium.webdriver.remote.webdriver import WebDriver
from sqlalchemy.exc import NoSuchColumnError, ResourceClosedError
from sqlalchemy.orm import Session
from superset import app, db, security_manager, thumbnail_cache
from superset import app, security_manager, thumbnail_cache
from superset.extensions import celery_app, machine_auth_provider_factory
from superset.models.alerts import Alert, AlertLog
from superset.models.dashboard import Dashboard
@ -62,6 +63,7 @@ from superset.models.slice import Slice
from superset.tasks.alerts.observer import observe
from superset.tasks.alerts.validator import get_validator_function
from superset.tasks.slack_util import deliver_slack_msg
from superset.utils.celery import session_scope
from superset.utils.core import get_email_address_list, send_email_smtp
from superset.utils.screenshots import ChartScreenshot, WebDriverProxy
from superset.utils.urls import get_url_path
@ -225,7 +227,7 @@ def destroy_webdriver(
pass
def deliver_dashboard(
def deliver_dashboard( # pylint: disable=too-many-locals
dashboard_id: int,
recipients: Optional[str],
slack_channel: Optional[str],
@ -236,69 +238,70 @@ def deliver_dashboard(
"""
Given a schedule, deliver the dashboard as an email report
"""
dashboard = db.session.query(Dashboard).filter_by(id=dashboard_id).one()
with session_scope(nullpool=True) as session:
dashboard = session.query(Dashboard).filter_by(id=dashboard_id).one()
dashboard_url = _get_url_path(
"Superset.dashboard", dashboard_id_or_slug=dashboard.id
)
dashboard_url_user_friendly = _get_url_path(
"Superset.dashboard", user_friendly=True, dashboard_id_or_slug=dashboard.id
)
# Create a driver, fetch the page, wait for the page to render
driver = create_webdriver()
window = config["WEBDRIVER_WINDOW"]["dashboard"]
driver.set_window_size(*window)
driver.get(dashboard_url)
time.sleep(EMAIL_PAGE_RENDER_WAIT)
# Set up a function to retry once for the element.
# This is buggy in certain selenium versions with firefox driver
get_element = getattr(driver, "find_element_by_class_name")
element = retry_call(
get_element, fargs=["grid-container"], tries=2, delay=EMAIL_PAGE_RENDER_WAIT
)
try:
screenshot = element.screenshot_as_png
except WebDriverException:
# Some webdrivers do not support screenshots for elements.
# In such cases, take a screenshot of the entire page.
screenshot = driver.screenshot() # pylint: disable=no-member
finally:
destroy_webdriver(driver)
# Generate the email body and attachments
report_content = _generate_report_content(
delivery_type,
screenshot,
dashboard.dashboard_title,
dashboard_url_user_friendly,
)
subject = __(
"%(prefix)s %(title)s",
prefix=config["EMAIL_REPORTS_SUBJECT_PREFIX"],
title=dashboard.dashboard_title,
)
if recipients:
_deliver_email(
recipients,
deliver_as_group,
subject,
report_content.body,
report_content.data,
report_content.images,
dashboard_url = _get_url_path(
"Superset.dashboard", dashboard_id_or_slug=dashboard.id
)
if slack_channel:
deliver_slack_msg(
slack_channel,
subject,
report_content.slack_message,
report_content.slack_attachment,
dashboard_url_user_friendly = _get_url_path(
"Superset.dashboard", user_friendly=True, dashboard_id_or_slug=dashboard.id
)
# Create a driver, fetch the page, wait for the page to render
driver = create_webdriver()
window = config["WEBDRIVER_WINDOW"]["dashboard"]
driver.set_window_size(*window)
driver.get(dashboard_url)
time.sleep(EMAIL_PAGE_RENDER_WAIT)
# Set up a function to retry once for the element.
# This is buggy in certain selenium versions with firefox driver
get_element = getattr(driver, "find_element_by_class_name")
element = retry_call(
get_element, fargs=["grid-container"], tries=2, delay=EMAIL_PAGE_RENDER_WAIT
)
try:
screenshot = element.screenshot_as_png
except WebDriverException:
# Some webdrivers do not support screenshots for elements.
# In such cases, take a screenshot of the entire page.
screenshot = driver.screenshot() # pylint: disable=no-member
finally:
destroy_webdriver(driver)
# Generate the email body and attachments
report_content = _generate_report_content(
delivery_type,
screenshot,
dashboard.dashboard_title,
dashboard_url_user_friendly,
)
subject = __(
"%(prefix)s %(title)s",
prefix=config["EMAIL_REPORTS_SUBJECT_PREFIX"],
title=dashboard.dashboard_title,
)
if recipients:
_deliver_email(
recipients,
deliver_as_group,
subject,
report_content.body,
report_content.data,
report_content.images,
)
if slack_channel:
deliver_slack_msg(
slack_channel,
subject,
report_content.slack_message,
report_content.slack_attachment,
)
def _get_slice_data(slc: Slice, delivery_type: EmailDeliveryType) -> ReportContent:
slice_url = _get_url_path(
@ -362,8 +365,8 @@ def _get_slice_data(slc: Slice, delivery_type: EmailDeliveryType) -> ReportConte
return ReportContent(body, data, None, slack_message, content)
def _get_slice_screenshot(slice_id: int) -> ScreenshotData:
slice_obj = db.session.query(Slice).get(slice_id)
def _get_slice_screenshot(slice_id: int, session: Session) -> ScreenshotData:
slice_obj = session.query(Slice).get(slice_id)
chart_url = get_url_path("Superset.slice", slice_id=slice_obj.id, standalone="true")
screenshot = ChartScreenshot(chart_url, slice_obj.digest)
@ -376,7 +379,7 @@ def _get_slice_screenshot(slice_id: int) -> ScreenshotData:
user=user, cache=thumbnail_cache, force=True,
)
db.session.commit()
session.commit()
return ScreenshotData(image_url, image_data)
@ -427,11 +430,12 @@ def deliver_slice( # pylint: disable=too-many-arguments
delivery_type: EmailDeliveryType,
email_format: SliceEmailReportFormat,
deliver_as_group: bool,
session: Session,
) -> None:
"""
Given a schedule, deliver the slice as an email report
"""
slc = db.session.query(Slice).filter_by(id=slice_id).one()
slc = session.query(Slice).filter_by(id=slice_id).one()
if email_format == SliceEmailReportFormat.data:
report_content = _get_slice_data(slc, delivery_type)
@ -477,38 +481,42 @@ def schedule_email_report( # pylint: disable=unused-argument
slack_channel: Optional[str] = None,
) -> None:
model_cls = get_scheduler_model(report_type)
schedule = db.create_scoped_session().query(model_cls).get(schedule_id)
with session_scope(nullpool=True) as session:
schedule = session.query(model_cls).get(schedule_id)
# The user may have disabled the schedule. If so, ignore this
if not schedule or not schedule.active:
logger.info("Ignoring deactivated schedule")
return
# The user may have disabled the schedule. If so, ignore this
if not schedule or not schedule.active:
logger.info("Ignoring deactivated schedule")
return
recipients = recipients or schedule.recipients
slack_channel = slack_channel or schedule.slack_channel
logger.info(
"Starting report for slack: %s and recipients: %s.", slack_channel, recipients
)
if report_type == ScheduleType.dashboard:
deliver_dashboard(
schedule.dashboard_id,
recipients,
recipients = recipients or schedule.recipients
slack_channel = slack_channel or schedule.slack_channel
logger.info(
"Starting report for slack: %s and recipients: %s.",
slack_channel,
schedule.delivery_type,
schedule.deliver_as_group,
)
elif report_type == ScheduleType.slice:
deliver_slice(
schedule.slice_id,
recipients,
slack_channel,
schedule.delivery_type,
schedule.email_format,
schedule.deliver_as_group,
)
else:
raise RuntimeError("Unknown report type")
if report_type == ScheduleType.dashboard:
deliver_dashboard(
schedule.dashboard_id,
recipients,
slack_channel,
schedule.delivery_type,
schedule.deliver_as_group,
)
elif report_type == ScheduleType.slice:
deliver_slice(
schedule.slice_id,
recipients,
slack_channel,
schedule.delivery_type,
schedule.email_format,
schedule.deliver_as_group,
session,
)
else:
raise RuntimeError("Unknown report type")
@celery_app.task(
@ -529,9 +537,8 @@ def schedule_alert_query( # pylint: disable=unused-argument
slack_channel: Optional[str] = None,
) -> None:
model_cls = get_scheduler_model(report_type)
try:
schedule = db.session.query(model_cls).get(schedule_id)
with session_scope(nullpool=True) as session:
schedule = session.query(model_cls).get(schedule_id)
# The user may have disabled the schedule. If so, ignore this
if not schedule or not schedule.active:
@ -539,15 +546,11 @@ def schedule_alert_query( # pylint: disable=unused-argument
return
if report_type == ScheduleType.alert:
evaluate_alert(schedule.id, schedule.label, recipients, slack_channel)
evaluate_alert(
schedule.id, schedule.label, session, recipients, slack_channel
)
else:
raise RuntimeError("Unknown report type")
except NoSuchColumnError as column_error:
stats_logger.incr("run_alert_task.error.nosuchcolumnerror")
raise column_error
except ResourceClosedError as resource_error:
stats_logger.incr("run_alert_task.error.resourceclosederror")
raise resource_error
class AlertState:
@ -558,6 +561,7 @@ class AlertState:
def deliver_alert(
alert_id: int,
session: Session,
recipients: Optional[str] = None,
slack_channel: Optional[str] = None,
) -> None:
@ -566,7 +570,7 @@ def deliver_alert(
to its respective email and slack recipients
"""
alert = db.session.query(Alert).get(alert_id)
alert = session.query(Alert).get(alert_id)
logging.info("Triggering alert: %s", alert)
@ -588,7 +592,7 @@ def deliver_alert(
str(alert.observations[-1].value),
validation_error_message,
_get_url_path("AlertModelView.show", user_friendly=True, pk=alert_id),
_get_slice_screenshot(alert.slice.id),
_get_slice_screenshot(alert.slice.id, session),
)
else:
# TODO: dashboard delivery!
@ -668,6 +672,7 @@ def deliver_slack_alert(alert_content: AlertContent, slack_channel: str) -> None
def evaluate_alert(
alert_id: int,
label: str,
session: Session,
recipients: Optional[str] = None,
slack_channel: Optional[str] = None,
) -> None:
@ -680,7 +685,7 @@ def evaluate_alert(
try:
logger.info("Querying observers for alert <%s:%s>", alert_id, label)
error_msg = observe(alert_id)
error_msg = observe(alert_id, session)
if error_msg:
state = AlertState.ERROR
logging.error(error_msg)
@ -694,17 +699,17 @@ def evaluate_alert(
if state != AlertState.ERROR:
# Don't validate alert on test runs since it may not be triggered
if recipients or slack_channel:
deliver_alert(alert_id, recipients, slack_channel)
deliver_alert(alert_id, session, recipients, slack_channel)
state = AlertState.TRIGGER
# Validate during regular workflow and deliver only if triggered
elif validate_observations(alert_id, label):
deliver_alert(alert_id, recipients, slack_channel)
elif validate_observations(alert_id, label, session):
deliver_alert(alert_id, session, recipients, slack_channel)
state = AlertState.TRIGGER
else:
state = AlertState.PASS
db.session.commit()
alert = db.session.query(Alert).get(alert_id)
session.commit()
alert = session.query(Alert).get(alert_id)
if state != AlertState.ERROR:
alert.last_eval_dttm = dttm_end
alert.last_state = state
@ -716,10 +721,10 @@ def evaluate_alert(
state=state,
)
)
db.session.commit()
session.commit()
def validate_observations(alert_id: int, label: str) -> bool:
def validate_observations(alert_id: int, label: str, session: Session) -> bool:
"""
Runs an alert's validators to check if it should be triggered or not
If so, return the name of the validator that returned true
@ -727,7 +732,7 @@ def validate_observations(alert_id: int, label: str) -> bool:
logger.info("Validating observations for alert <%s:%s>", alert_id, label)
alert = db.session.query(Alert).get(alert_id)
alert = session.query(Alert).get(alert_id)
if alert.validators:
validator = alert.validators[0]
validate = get_validator_function(validator.validator_type)
@ -760,7 +765,11 @@ def next_schedules(
def schedule_window(
report_type: str, start_at: datetime, stop_at: datetime, resolution: int
report_type: str,
start_at: datetime,
stop_at: datetime,
resolution: int,
session: Session,
) -> None:
"""
Find all active schedules and schedule celery tasks for
@ -772,8 +781,7 @@ def schedule_window(
if not model_cls:
return None
dbsession = db.create_scoped_session()
schedules = dbsession.query(model_cls).filter(model_cls.active.is_(True))
schedules = session.query(model_cls).filter(model_cls.active.is_(True))
for schedule in schedules:
logging.info("Processing schedule %s", schedule)
@ -810,7 +818,6 @@ def get_scheduler_action(report_type: str) -> Optional[Callable[..., Any]]:
@celery_app.task(name="email_reports.schedule_hourly")
def schedule_hourly() -> None:
""" Celery beat job meant to be invoked hourly """
if not config["ENABLE_SCHEDULED_EMAIL_REPORTS"]:
logger.info("Scheduled email reports not enabled in config")
return
@ -820,8 +827,10 @@ def schedule_hourly() -> None:
# Get the top of the hour
start_at = datetime.now(tzlocal()).replace(microsecond=0, second=0, minute=0)
stop_at = start_at + timedelta(seconds=3600)
schedule_window(ScheduleType.dashboard, start_at, stop_at, resolution)
schedule_window(ScheduleType.slice, start_at, stop_at, resolution)
with session_scope(nullpool=True) as session:
schedule_window(ScheduleType.dashboard, start_at, stop_at, resolution, session)
schedule_window(ScheduleType.slice, start_at, stop_at, resolution, session)
@celery_app.task(name="alerts.schedule_check")
@ -833,5 +842,5 @@ def schedule_alerts() -> None:
seconds=3600
) # process any missed tasks in the past hour
stop_at = now + timedelta(seconds=1)
schedule_window(ScheduleType.alert, start_at, stop_at, resolution)
with session_scope(nullpool=True) as session:
schedule_window(ScheduleType.alert, start_at, stop_at, resolution, session)

57
superset/utils/celery.py Normal file
View File

@ -0,0 +1,57 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Iterator
import sqlalchemy as sa
from contextlib2 import contextmanager
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import NullPool
from superset import app, db
logger = logging.getLogger(__name__)
# Null pool is used for the celery workers due to process-forking side effects.
# For more info see: https://github.com/apache/incubator-superset/issues/10530
@contextmanager
def session_scope(nullpool: bool) -> Iterator[Session]:
    """Yield a SQLAlchemy session wrapped in a commit/rollback transaction.

    When ``nullpool`` is true, a dedicated engine backed by ``NullPool``
    is created so that forked celery worker processes never share pooled
    database connections; otherwise the application-scoped
    ``db.session()`` is reused.

    The session is committed when the ``with`` body exits normally,
    rolled back (and the exception re-raised) on error, and always
    closed afterwards.
    """
    database_uri = app.config["SQLALCHEMY_DATABASE_URI"]
    if "sqlite" in database_uri:
        # SQLite metadata databases are deprecated; warn on every use.
        logger.warning(
            "SQLite Database support for metadata databases will be removed \
in a future version of Superset."
        )
    if nullpool:
        # Fresh engine per call: no connection pooling across forks.
        factory = sessionmaker(
            bind=sa.create_engine(database_uri, poolclass=NullPool)
        )
        session = factory()
    else:
        session = db.session()
        session.commit()  # HACK
    try:
        yield session
        session.commit()
    except Exception as ex:
        session.rollback()
        logger.exception(ex)
        raise
    finally:
        session.close()

View File

@ -112,37 +112,37 @@ def test_alert_observer(setup_database):
# Test SQLObserver with int SQL return
alert1 = create_alert(dbsession, "SELECT 55")
observe(alert1.id)
observe(alert1.id, dbsession)
assert alert1.sql_observer[0].observations[-1].value == 55.0
assert alert1.sql_observer[0].observations[-1].error_msg is None
# Test SQLObserver with double SQL return
alert2 = create_alert(dbsession, "SELECT 30.0 as wage")
observe(alert2.id)
observe(alert2.id, dbsession)
assert alert2.sql_observer[0].observations[-1].value == 30.0
assert alert2.sql_observer[0].observations[-1].error_msg is None
# Test SQLObserver with NULL result
alert3 = create_alert(dbsession, "SELECT null as null_result")
observe(alert3.id)
observe(alert3.id, dbsession)
assert alert3.sql_observer[0].observations[-1].value is None
assert alert3.sql_observer[0].observations[-1].error_msg is None
# Test SQLObserver with empty SQL return
alert4 = create_alert(dbsession, "SELECT first FROM test_table WHERE first = -1")
observe(alert4.id)
observe(alert4.id, dbsession)
assert alert4.sql_observer[0].observations[-1].value is None
assert alert4.sql_observer[0].observations[-1].error_msg is not None
# Test SQLObserver with str result
alert5 = create_alert(dbsession, "SELECT 'test_string' as string_value")
observe(alert5.id)
observe(alert5.id, dbsession)
assert alert5.sql_observer[0].observations[-1].value is None
assert alert5.sql_observer[0].observations[-1].error_msg is not None
# Test SQLObserver with two row result
alert6 = create_alert(dbsession, "SELECT first FROM test_table")
observe(alert6.id)
observe(alert6.id, dbsession)
assert alert6.sql_observer[0].observations[-1].value is None
assert alert6.sql_observer[0].observations[-1].error_msg is not None
@ -150,7 +150,7 @@ def test_alert_observer(setup_database):
alert7 = create_alert(
dbsession, "SELECT first, second FROM test_table WHERE first = 1"
)
observe(alert7.id)
observe(alert7.id, dbsession)
assert alert7.sql_observer[0].observations[-1].value is None
assert alert7.sql_observer[0].observations[-1].error_msg is not None
@ -161,22 +161,22 @@ def test_evaluate_alert(mock_deliver_alert, setup_database):
# Test error with Observer SQL statement
alert1 = create_alert(dbsession, "$%^&")
evaluate_alert(alert1.id, alert1.label)
evaluate_alert(alert1.id, alert1.label, dbsession)
assert alert1.logs[-1].state == AlertState.ERROR
# Test error with alert lacking observer
alert2 = dbsession.query(Alert).filter_by(label="No Observer").one()
evaluate_alert(alert2.id, alert2.label)
evaluate_alert(alert2.id, alert2.label, dbsession)
assert alert2.logs[-1].state == AlertState.ERROR
# Test pass on alert lacking validator
alert3 = create_alert(dbsession, "SELECT 55")
evaluate_alert(alert3.id, alert3.label)
evaluate_alert(alert3.id, alert3.label, dbsession)
assert alert3.logs[-1].state == AlertState.PASS
# Test triggering successful alert
alert4 = create_alert(dbsession, "SELECT 55", "not null", "{}")
evaluate_alert(alert4.id, alert4.label)
evaluate_alert(alert4.id, alert4.label, dbsession)
assert mock_deliver_alert.call_count == 1
assert alert4.logs[-1].state == AlertState.TRIGGER
@ -214,17 +214,17 @@ def test_not_null_validator(setup_database):
# Test passing SQLObserver with 'null' SQL result
alert1 = create_alert(dbsession, "SELECT 0")
observe(alert1.id)
observe(alert1.id, dbsession)
assert not_null_validator(alert1.sql_observer[0], "{}") is False
# Test passing SQLObserver with empty SQL result
alert2 = create_alert(dbsession, "SELECT first FROM test_table WHERE first = -1")
observe(alert2.id)
observe(alert2.id, dbsession)
assert not_null_validator(alert2.sql_observer[0], "{}") is False
# Test triggering alert with non-null SQL result
alert3 = create_alert(dbsession, "SELECT 55")
observe(alert3.id)
observe(alert3.id, dbsession)
assert not_null_validator(alert3.sql_observer[0], "{}") is True
@ -233,7 +233,7 @@ def test_operator_validator(setup_database):
# Test passing SQLObserver with empty SQL result
alert1 = create_alert(dbsession, "SELECT first FROM test_table WHERE first = -1")
observe(alert1.id)
observe(alert1.id, dbsession)
assert (
operator_validator(alert1.sql_observer[0], '{"op": ">=", "threshold": 60}')
is False
@ -241,7 +241,7 @@ def test_operator_validator(setup_database):
# Test passing SQLObserver with result that doesn't pass a greater than threshold
alert2 = create_alert(dbsession, "SELECT 55")
observe(alert2.id)
observe(alert2.id, dbsession)
assert (
operator_validator(alert2.sql_observer[0], '{"op": ">=", "threshold": 60}')
is False
@ -283,23 +283,23 @@ def test_validate_observations(setup_database):
# Test False on alert with no validator
alert1 = create_alert(dbsession, "SELECT 55")
assert validate_observations(alert1.id, alert1.label) is False
assert validate_observations(alert1.id, alert1.label, dbsession) is False
# Test False on alert with no observations
alert2 = create_alert(dbsession, "SELECT 55", "not null", "{}")
assert validate_observations(alert2.id, alert2.label) is False
assert validate_observations(alert2.id, alert2.label, dbsession) is False
# Test False on alert that shouldnt be triggered
alert3 = create_alert(dbsession, "SELECT 0", "not null", "{}")
observe(alert3.id)
assert validate_observations(alert3.id, alert3.label) is False
observe(alert3.id, dbsession)
assert validate_observations(alert3.id, alert3.label, dbsession) is False
# Test True on alert that should be triggered
alert4 = create_alert(
dbsession, "SELECT 55", "operator", '{"op": "<=", "threshold": 60}'
)
observe(alert4.id)
assert validate_observations(alert4.id, alert4.label) is True
observe(alert4.id, dbsession)
assert validate_observations(alert4.id, alert4.label, dbsession) is True
@patch("superset.tasks.slack_util.WebClient.files_upload")
@ -311,7 +311,7 @@ def test_deliver_alert_screenshot(
):
dbsession = setup_database
alert = create_alert(dbsession, "SELECT 55", "not null", "{}")
observe(alert.id)
observe(alert.id, dbsession)
screenshot = read_fixture("sample.png")
screenshot_mock.return_value = screenshot
@ -322,7 +322,7 @@ def test_deliver_alert_screenshot(
f"http://0.0.0.0:8080/superset/slice/{alert.slice_id}/",
]
deliver_alert(alert_id=alert.id)
deliver_alert(alert.id, dbsession)
assert email_mock.call_args[1]["images"]["screenshot"] == screenshot
assert file_upload_mock.call_args[1] == {
"channels": alert.slack_channel,

View File

@ -366,6 +366,7 @@ class TestSchedules(SupersetTestCase):
schedule.delivery_type,
schedule.email_format,
schedule.deliver_as_group,
db.session,
)
mtime.sleep.assert_called_once()
driver.screenshot.assert_not_called()
@ -418,6 +419,7 @@ class TestSchedules(SupersetTestCase):
schedule.delivery_type,
schedule.email_format,
schedule.deliver_as_group,
db.session,
)
mtime.sleep.assert_called_once()
@ -466,6 +468,7 @@ class TestSchedules(SupersetTestCase):
schedule.delivery_type,
schedule.email_format,
schedule.deliver_as_group,
db.session,
)
send_email_smtp.assert_called_once()
@ -510,6 +513,7 @@ class TestSchedules(SupersetTestCase):
schedule.delivery_type,
schedule.email_format,
schedule.deliver_as_group,
db.session,
)
send_email_smtp.assert_called_once()