feat(alerts): Select tabs to send backend (#17749)

* Adding the extra config and validation

* wip

* reports working

* Tests working

* fix type

* Fix lint errors

* Fixing type issues

* add licence header

* fix the fixture deleting problem

* scope to session

* fix integration test

* fix review comments

* fix review comments patch 2

Co-authored-by: Grace Guo <grace.guo@airbnb.com>
This commit is contained in:
Ajay M 2022-01-11 10:48:50 -08:00 committed by GitHub
parent 46715b295c
commit bdc35a2214
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 318 additions and 66 deletions

View File

@ -16,6 +16,8 @@
# under the License.
"""A collection of ORM sqlalchemy models for Superset"""
import enum
import json
from typing import Any, Dict, Optional
from cron_descriptor import get_description
from flask_appbuilder import Model
@ -31,7 +33,7 @@ from sqlalchemy import (
Table,
Text,
)
from sqlalchemy.orm import backref, relationship
from sqlalchemy.orm import backref, relationship, validates
from sqlalchemy.schema import UniqueConstraint
from sqlalchemy_utils import UUIDType
@ -158,6 +160,13 @@ class ReportSchedule(Model, AuditMixinNullable):
def crontab_humanized(self) -> str:
return get_description(self.crontab)
@validates("extra")
# pylint: disable=unused-argument,no-self-use
def validate_extra(self, key: str, value: Dict[Any, Any]) -> Optional[str]:
if value is not None:
return json.dumps(value)
return None
class ReportRecipients(Model, AuditMixinNullable):
"""

View File

@ -88,6 +88,7 @@ class CreateReportScheduleCommand(CreateMixin, BaseReportScheduleCommand):
# Validate chart or dashboard relations
self.validate_chart_dashboard(exceptions)
self._validate_report_extra(exceptions)
# Validate that each chart or dashboard only has one report with
# the respective creation method.
@ -113,3 +114,22 @@ class CreateReportScheduleCommand(CreateMixin, BaseReportScheduleCommand):
exception = ReportScheduleInvalidError()
exception.add_list(exceptions)
raise exception
def _validate_report_extra(self, exceptions: List[ValidationError]) -> None:
extra = self._properties.get("extra")
dashboard = self._properties.get("dashboard")
if extra is None or dashboard is None:
return
dashboard_tab_ids = extra.get("dashboard_tab_ids")
if dashboard_tab_ids is None:
return
position_data = json.loads(dashboard.position_json)
invalid_tab_ids = [
tab_id for tab_id in dashboard_tab_ids if tab_id not in position_data
]
if invalid_tab_ids:
exceptions.append(
ValidationError(f"Invalid tab IDs selected: {invalid_tab_ids}", "extra")
)

View File

@ -187,41 +187,55 @@ class BaseReportState:
raise ReportScheduleSelleniumUserNotFoundError()
return user
def _get_screenshot(self) -> bytes:
def _get_screenshots(self) -> List[bytes]:
"""
Get a chart or dashboard screenshot
Get chart or dashboard screenshots
:raises: ReportScheduleScreenshotFailedError
"""
screenshot: Optional[BaseScreenshot] = None
image_data = []
screenshots: List[BaseScreenshot] = []
if self._report_schedule.chart:
url = self._get_url()
logger.info("Screenshotting chart at %s", url)
screenshot = ChartScreenshot(
url,
self._report_schedule.chart.digest,
window_size=app.config["WEBDRIVER_WINDOW"]["slice"],
thumb_size=app.config["WEBDRIVER_WINDOW"]["slice"],
)
screenshots = [
ChartScreenshot(
url,
self._report_schedule.chart.digest,
window_size=app.config["WEBDRIVER_WINDOW"]["slice"],
thumb_size=app.config["WEBDRIVER_WINDOW"]["slice"],
)
]
else:
url = self._get_url()
logger.info("Screenshotting dashboard at %s", url)
screenshot = DashboardScreenshot(
url,
self._report_schedule.dashboard.digest,
window_size=app.config["WEBDRIVER_WINDOW"]["dashboard"],
thumb_size=app.config["WEBDRIVER_WINDOW"]["dashboard"],
tabs: Optional[List[str]] = json.loads(self._report_schedule.extra).get(
"dashboard_tab_ids", None
)
dashboard_base_url = self._get_url()
if tabs is None:
urls = [dashboard_base_url]
else:
urls = [f"{dashboard_base_url}#{tab_id}" for tab_id in tabs]
screenshots = [
DashboardScreenshot(
url,
self._report_schedule.dashboard.digest,
window_size=app.config["WEBDRIVER_WINDOW"]["dashboard"],
thumb_size=app.config["WEBDRIVER_WINDOW"]["dashboard"],
)
for url in urls
]
user = self._get_user()
try:
image_data = screenshot.get_screenshot(user=user)
except SoftTimeLimitExceeded as ex:
logger.warning("A timeout occurred while taking a screenshot.")
raise ReportScheduleScreenshotTimeout() from ex
except Exception as ex:
raise ReportScheduleScreenshotFailedError(
f"Failed taking a screenshot {str(ex)}"
) from ex
for screenshot in screenshots:
try:
image = screenshot.get_screenshot(user=user)
except SoftTimeLimitExceeded as ex:
logger.warning("A timeout occurred while taking a screenshot.")
raise ReportScheduleScreenshotTimeout() from ex
except Exception as ex:
raise ReportScheduleScreenshotFailedError(
f"Failed taking a screenshot {str(ex)}"
) from ex
if image is not None:
image_data.append(image)
if not image_data:
raise ReportScheduleScreenshotFailedError()
return image_data
@ -285,7 +299,7 @@ class BaseReportState:
context.
"""
try:
self._get_screenshot()
self._get_screenshots()
except (
ReportScheduleScreenshotFailedError,
ReportScheduleScreenshotTimeout,
@ -305,14 +319,14 @@ class BaseReportState:
csv_data = None
embedded_data = None
error_text = None
screenshot_data = None
screenshot_data = []
url = self._get_url(user_friendly=True)
if (
feature_flag_manager.is_feature_enabled("ALERTS_ATTACH_REPORTS")
or self._report_schedule.type == ReportScheduleType.REPORT
):
if self._report_schedule.report_format == ReportDataFormat.VISUALIZATION:
screenshot_data = self._get_screenshot()
screenshot_data = self._get_screenshots()
if not screenshot_data:
error_text = "Unexpected missing screenshot"
elif (
@ -346,7 +360,7 @@ class BaseReportState:
return NotificationContent(
name=name,
url=url,
screenshot=screenshot_data,
screenshots=screenshot_data,
description=self._report_schedule.description,
csv=csv_data,
embedded_data=embedded_data,

View File

@ -27,7 +27,7 @@ from superset.models.reports import ReportRecipients, ReportRecipientType
class NotificationContent:
name: str
csv: Optional[bytes] = None # bytes for csv file
screenshot: Optional[bytes] = None # bytes for the screenshot
screenshots: Optional[List[bytes]] = None # bytes for a list of screenshots
text: Optional[str] = None
description: Optional[str] = ""
url: Optional[str] = None # url to chart/dashboard for this screenshot

View File

@ -69,10 +69,15 @@ class EmailNotification(BaseNotification): # pylint: disable=too-few-public-met
return EmailContent(body=self._error_template(self._content.text))
# Get the domain from the 'From' address ..
# and make a message id without the < > in the end
image = None
csv_data = None
domain = self._get_smtp_domain()
msgid = make_msgid(domain)[1:-1]
images = {}
if self._content.screenshots:
images = {
make_msgid(domain)[1:-1]: screenshot
for screenshot in self._content.screenshots
}
# Strip any malicious HTML from the description
description = bleach.clean(self._content.description or "")
@ -89,11 +94,16 @@ class EmailNotification(BaseNotification): # pylint: disable=too-few-public-met
html_table = ""
call_to_action = __("Explore in Superset")
img_tag = (
f'<img width="1000px" src="cid:{msgid}">'
if self._content.screenshot
else ""
)
img_tags = []
for msgid in images.keys():
img_tags.append(
f"""<div class="image">
<img width="1000px" src="cid:{msgid}">
</div>
<
"""
)
img_tag = "".join(img_tags)
body = textwrap.dedent(
f"""
<html>
@ -105,6 +115,9 @@ class EmailNotification(BaseNotification): # pylint: disable=too-few-public-met
color: rgb(42, 63, 95);
padding: 4px 8px;
}}
.image{{
margin-bottom: 18px;
}}
</style>
</head>
<body>
@ -116,11 +129,10 @@ class EmailNotification(BaseNotification): # pylint: disable=too-few-public-met
</html>
"""
)
if self._content.screenshot:
image = {msgid: self._content.screenshot}
if self._content.csv:
csv_data = {__("%(name)s.csv", name=self._content.name): self._content.csv}
return EmailContent(body=body, images=image, data=csv_data)
return EmailContent(body=body, images=images, data=csv_data)
def _get_subject(self) -> str:
return __(

View File

@ -18,7 +18,7 @@
import json
import logging
from io import IOBase
from typing import Optional, Union
from typing import Sequence, Union
import backoff
from flask_babel import gettext as __
@ -133,16 +133,16 @@ Error: %(text)s
return self._message_template(table)
def _get_inline_file(self) -> Optional[Union[str, IOBase, bytes]]:
def _get_inline_files(self) -> Sequence[Union[str, IOBase, bytes]]:
if self._content.csv:
return self._content.csv
if self._content.screenshot:
return self._content.screenshot
return None
return [self._content.csv]
if self._content.screenshots:
return self._content.screenshots
return []
@backoff.on_exception(backoff.expo, SlackApiError, factor=10, base=2, max_tries=5)
def send(self) -> None:
file = self._get_inline_file()
files = self._get_inline_files()
title = self._content.name
channel = self._get_channel()
body = self._get_body()
@ -153,14 +153,15 @@ Error: %(text)s
token = token()
client = WebClient(token=token, proxy=app.config["SLACK_PROXY"])
# files_upload returns SlackResponse as we run it in sync mode.
if file:
client.files_upload(
channels=channel,
file=file,
initial_comment=body,
title=title,
filetype=file_type,
)
if files:
for file in files:
client.files_upload(
channels=channel,
file=file,
initial_comment=body,
title=title,
filetype=file_type,
)
else:
client.chat_postMessage(channel=channel, text=body)
logger.info("Report sent to slack")

View File

@ -170,6 +170,7 @@ class ReportSchedulePostSchema(Schema):
description=creation_method_description,
)
dashboard = fields.Integer(required=False, allow_none=True)
selected_tabs = fields.List(fields.Integer(), required=False, allow_none=True)
database = fields.Integer(required=False)
owners = fields.List(fields.Integer(description=owners_description))
validator_type = fields.String(
@ -202,6 +203,7 @@ class ReportSchedulePostSchema(Schema):
default=ReportDataFormat.VISUALIZATION,
validate=validate.OneOf(choices=tuple(key.value for key in ReportDataFormat)),
)
extra = fields.Dict(default=None,)
force_screenshot = fields.Boolean(default=False)
@validates_schema

View File

@ -80,7 +80,6 @@ def create_dashboard(
slug: str, title: str, position: str, slices: List[Slice]
) -> Dashboard:
dash = db.session.query(Dashboard).filter_by(slug=slug).one_or_none()
if not dash:
dash = Dashboard()
dash.dashboard_title = title

View File

@ -0,0 +1,77 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json
import pytest
from superset import db
from superset.models.dashboard import Dashboard
from tests.integration_tests.dashboard_utils import create_dashboard
from tests.integration_tests.test_app import app
@pytest.fixture(scope="session")
def tabbed_dashboard():
position_json = {
"DASHBOARD_VERSION_KEY": "v2",
"GRID_ID": {
"children": ["TABS-IpViLohnyP"],
"id": "GRID_ID",
"parents": ["ROOT_ID"],
"type": "GRID",
},
"HEADER_ID": {
"id": "HEADER_ID",
"meta": {"text": "tabbed dashboard"},
"type": "HEADER",
},
"ROOT_ID": {"children": ["GRID_ID"], "id": "ROOT_ID", "type": "ROOT"},
"TAB-j53G4gtKGF": {
"children": [],
"id": "TAB-j53G4gtKGF",
"meta": {
"defaultText": "Tab title",
"placeholder": "Tab title",
"text": "Tab 1",
},
"parents": ["ROOT_ID", "GRID_ID", "TABS-IpViLohnyP"],
"type": "TAB",
},
"TAB-nerWR09Ju": {
"children": [],
"id": "TAB-nerWR09Ju",
"meta": {
"defaultText": "Tab title",
"placeholder": "Tab title",
"text": "Tab 2",
},
"parents": ["ROOT_ID", "GRID_ID", "TABS-IpViLohnyP"],
"type": "TAB",
},
"TABS-IpViLohnyP": {
"children": ["TAB-j53G4gtKGF", "TAB-nerWR09Ju"],
"id": "TABS-IpViLohnyP",
"meta": {},
"parents": ["ROOT_ID", "GRID_ID"],
"type": "TABS",
},
}
with app.app_context():
dash = create_dashboard(
"tabbed-dash-test", "Tabbed Dash Test", json.dumps(position_json), []
)
yield dash

View File

@ -45,9 +45,9 @@ from tests.integration_tests.fixtures.birth_names_dashboard import (
load_birth_names_dashboard_with_slices,
load_birth_names_data,
)
from tests.integration_tests.fixtures.tabbed_dashboard import tabbed_dashboard
from tests.integration_tests.reports.utils import insert_report_schedule
REPORTS_COUNT = 10
@ -1526,3 +1526,80 @@ class TestReportSchedulesApi(SupersetTestCase):
assert rv.status_code == 405
rv = self.client.delete(uri)
assert rv.status_code == 405
@pytest.mark.usefixtures("create_report_schedules")
@pytest.mark.usefixtures("tabbed_dashboard")
def test_when_invalid_tab_ids_are_given_it_raises_bad_request(self):
"""
when tab ids are specified in the extra argument, make sure that the
tab ids are valid.
"""
self.login(username="admin")
dashboard = (
db.session.query(Dashboard)
.filter(Dashboard.slug == "tabbed-dash-test")
.first()
)
example_db = get_example_database()
report_schedule_data = {
"type": ReportScheduleType.ALERT,
"name": "new3",
"description": "description",
"crontab": "0 9 * * *",
"creation_method": ReportCreationMethodType.ALERTS_REPORTS,
"recipients": [
{
"type": ReportRecipientType.EMAIL,
"recipient_config_json": {"target": "target@superset.org"},
},
],
"grace_period": 14400,
"working_timeout": 3600,
"chart": None,
"dashboard": dashboard.id,
"database": example_db.id,
"extra": {"dashboard_tab_ids": ["INVALID-TAB-ID-1", "TABS-IpViLohnyP"]},
}
response = self.client.post("api/v1/report/", json=report_schedule_data)
assert response.status_code == 422
assert response.json == {
"message": {"extra": ["Invalid tab IDs selected: ['INVALID-TAB-ID-1']"]}
}
@pytest.mark.usefixtures("create_report_schedules")
@pytest.mark.usefixtures("tabbed_dashboard")
def test_when_tab_ids_are_given_it_gets_added_to_extra(self):
self.login(username="admin")
dashboard = (
db.session.query(Dashboard)
.filter(Dashboard.slug == "tabbed-dash-test")
.first()
)
example_db = get_example_database()
report_schedule_data = {
"type": ReportScheduleType.ALERT,
"name": "new3",
"description": "description",
"crontab": "0 9 * * *",
"creation_method": ReportCreationMethodType.ALERTS_REPORTS,
"recipients": [
{
"type": ReportRecipientType.EMAIL,
"recipient_config_json": {"target": "target@superset.org"},
},
],
"grace_period": 14400,
"working_timeout": 3600,
"chart": None,
"dashboard": dashboard.id,
"database": example_db.id,
"extra": {"dashboard_tab_ids": ["TABS-IpViLohnyP"]},
}
response = self.client.post("api/v1/report/", json=report_schedule_data)
assert response.status_code == 201
assert json.loads(
db.session.query(ReportSchedule)
.filter(ReportSchedule.id == response.json["id"])
.first()
.extra
) == {"dashboard_tab_ids": ["TABS-IpViLohnyP"]}

View File

@ -17,7 +17,7 @@
import json
from contextlib import contextmanager
from datetime import datetime, timedelta
from typing import List, Optional
from typing import Any, Dict, List, Optional
from unittest.mock import Mock, patch
from uuid import uuid4
@ -61,6 +61,7 @@ from tests.integration_tests.fixtures.birth_names_dashboard import (
load_birth_names_dashboard_with_slices,
load_birth_names_data,
)
from tests.integration_tests.fixtures.tabbed_dashboard import tabbed_dashboard
from tests.integration_tests.fixtures.world_bank_dashboard import (
load_world_bank_dashboard_with_slices_module_scope,
load_world_bank_data,
@ -135,6 +136,7 @@ def create_report_notification(
grace_period: Optional[int] = None,
report_format: Optional[ReportDataFormat] = None,
name: Optional[str] = None,
extra: Optional[Dict[str, Any]] = None,
force_screenshot: bool = False,
) -> ReportSchedule:
report_type = report_type or ReportScheduleType.REPORT
@ -175,6 +177,7 @@ def create_report_notification(
validator_config_json=validator_config_json,
grace_period=grace_period,
report_format=report_format or ReportDataFormat.VISUALIZATION,
extra=extra,
force_screenshot=force_screenshot,
)
return report_schedule
@ -287,6 +290,18 @@ def create_report_email_dashboard():
cleanup_report_schedule(report_schedule)
@pytest.fixture()
def create_report_email_tabbed_dashboard(tabbed_dashboard):
with app.app_context():
report_schedule = create_report_notification(
email_target="target@email.com",
dashboard=tabbed_dashboard,
extra={"dashboard_tab_ids": ["TAB-j53G4gtKGF", "TAB-nerWR09Ju",]},
)
yield report_schedule
cleanup_report_schedule(report_schedule)
@pytest.fixture()
def create_report_slack_chart():
with app.app_context():
@ -1314,7 +1329,7 @@ def test_slack_chart_alert_no_attachment(email_mock, create_alert_email_chart):
# Assert the email smtp address
assert email_mock.call_args[0][0] == notification_targets[0]
# Assert the there is no attached image
assert email_mock.call_args[1]["images"] is None
assert email_mock.call_args[1]["images"] == {}
# Assert logs are correct
assert_log(ReportState.SUCCESS)
@ -1553,9 +1568,7 @@ def test_fail_csv(
TEST_ID, create_report_email_chart_with_csv.id, datetime.utcnow()
).run()
notification_targets = get_target_from_report_schedule(
create_report_email_chart_with_csv
)
get_target_from_report_schedule(create_report_email_chart_with_csv)
# Assert the email smtp address, asserts a notification was sent with the error
assert email_mock.call_args[0][0] == OWNER_EMAIL
@ -1585,7 +1598,7 @@ def test_email_disable_screenshot(email_mock, create_alert_email_chart):
# Assert the email smtp address, asserts a notification was sent with the error
assert email_mock.call_args[0][0] == notification_targets[0]
# Assert the there is no attached image
assert email_mock.call_args[1]["images"] is None
assert email_mock.call_args[1]["images"] == {}
assert_log(ReportState.SUCCESS)
@ -1733,3 +1746,29 @@ def test_prune_log_soft_time_out(bulk_delete_logs, create_report_email_dashboard
with pytest.raises(SoftTimeLimitExceeded) as excinfo:
AsyncPruneReportScheduleLogCommand().run()
assert str(excinfo.value) == "SoftTimeLimitExceeded()"
@pytest.mark.usefixtures("create_report_email_tabbed_dashboard",)
@patch("superset.reports.notifications.email.send_email_smtp")
@patch("superset.reports.commands.execute.DashboardScreenshot",)
def test_when_tabs_are_selected_it_takes_screenshots_for_every_tabs(
dashboard_screenshot_mock,
send_email_smtp_mock,
create_report_email_tabbed_dashboard,
):
dashboard_screenshot_mock.get_screenshot.return_value = b"test-image"
dashboard = create_report_email_tabbed_dashboard.dashboard
AsyncExecuteReportScheduleCommand(
TEST_ID, create_report_email_tabbed_dashboard.id, datetime.utcnow()
).run()
tabs = json.loads(create_report_email_tabbed_dashboard.extra)["dashboard_tab_ids"]
assert dashboard_screenshot_mock.call_count == 2
for index, tab in enumerate(tabs):
assert dashboard_screenshot_mock.call_args_list[index].args == (
f"http://0.0.0.0:8080/superset/dashboard/{dashboard.id}/?standalone=3#{tab}",
f"{dashboard.digest}",
)
assert send_email_smtp_mock.called is True
assert len(send_email_smtp_mock.call_args.kwargs["images"]) == 2

View File

@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from typing import List, Optional
from typing import Any, Dict, List, Optional
from flask_appbuilder.security.sqla.models import User
@ -51,6 +51,7 @@ def insert_report_schedule(
recipients: Optional[List[ReportRecipients]] = None,
report_format: Optional[ReportDataFormat] = None,
logs: Optional[List[ReportExecutionLog]] = None,
extra: Optional[Dict[Any, Any]] = None,
force_screenshot: bool = False,
) -> ReportSchedule:
owners = owners or []
@ -76,6 +77,7 @@ def insert_report_schedule(
logs=logs,
last_state=last_state,
report_format=report_format,
extra=extra,
force_screenshot=force_screenshot,
)
db.session.add(report_schedule)