diff --git a/superset/reports/notifications/slack.py b/superset/reports/notifications/slack.py index 9ecf8de6d..6a6bbda7b 100644 --- a/superset/reports/notifications/slack.py +++ b/superset/reports/notifications/slack.py @@ -17,12 +17,14 @@ import logging from collections.abc import Sequence from io import IOBase -from typing import Union +from typing import List, Union import backoff import pandas as pd +from deprecation import deprecated from flask import g from flask_babel import gettext as __ +from slack_sdk import WebClient from slack_sdk.errors import ( BotUserAccessError, SlackApiError, @@ -60,16 +62,25 @@ class SlackNotification(BaseNotification): # pylint: disable=too-few-public-met type = ReportRecipientType.SLACK - def _get_channel(self) -> str: + def _get_channels(self, client: WebClient) -> List[str]: """ Get the recipient's channel(s). - Note Slack SDK uses "channel" to refer to one or more - channels. Multiple channels are demarcated by a comma. - :returns: The comma separated list of channel(s) + :returns: A list of channel ids: "EID676L" + :raises SlackApiError: If the API call fails """ recipient_str = json.loads(self._recipient.recipient_config_json)["target"] - return ",".join(get_email_address_list(recipient_str)) + channel_recipients: List[str] = get_email_address_list(recipient_str) + + conversations_list_response = client.conversations_list( + types="public_channel,private_channel" + ) + + return [ + c["id"] + for c in conversations_list_response["channels"] + if c["name"] in channel_recipients + ] def _message_template(self, table: str = "") -> str: return __( @@ -115,15 +126,19 @@ Error: %(text)s # Flatten columns/index so they show up nicely in the table df.columns = [ - " ".join(str(name) for name in column).strip() - if isinstance(column, tuple) - else column + ( + " ".join(str(name) for name in column).strip() + if isinstance(column, tuple) + else column + ) for column in df.columns ] df.index = [ - " ".join(str(name) for name in index).strip() - if isinstance(index, tuple) - else index + ( + " ".join(str(name) for name in index).strip() + if isinstance(index, tuple) + else index + ) for index in df.index ] @@ -162,29 +177,40 @@ Error: %(text)s def _get_inline_files( self, - ) -> tuple[Union[str, None], Sequence[Union[str, IOBase, bytes]]]: + ) -> Sequence[Union[str, IOBase, bytes]]: if self._content.csv: - return ("csv", [self._content.csv]) + return [self._content.csv] if self._content.screenshots: - return ("png", self._content.screenshots) + return self._content.screenshots if self._content.pdf: - return ("pdf", [self._content.pdf]) - return (None, []) + return [self._content.pdf] + return [] - @backoff.on_exception(backoff.expo, SlackApiError, factor=10, base=2, max_tries=5) - @statsd_gauge("reports.slack.send") - def send(self) -> None: - file_type, files = self._get_inline_files() - title = self._content.name - channel = self._get_channel() - body = self._get_body() - global_logs_context = getattr(g, "logs_context", {}) or {} - try: - client = get_slack_client() - # files_upload returns SlackResponse as we run it in sync mode. - if files: + @deprecated(deprecated_in="4.1") + def _deprecated_upload_files( + self, client: WebClient, title: str, body: str + ) -> None: + """ + Deprecated method to upload files to slack + Should only be used if the new method fails + To be removed in the next major release + """ + file_type, files = (None, []) + if self._content.csv: + file_type, files = ("csv", [self._content.csv]) + if self._content.screenshots: + file_type, files = ("png", self._content.screenshots) + if self._content.pdf: + file_type, files = ("pdf", [self._content.pdf]) + + recipient_str = json.loads(self._recipient.recipient_config_json)["target"] + + recipients = get_email_address_list(recipient_str) + + for channel in recipients: + if len(files) > 0: for file in files: - client.files_upload_v2( + client.files_upload( channels=channel, file=file, initial_comment=body, @@ -193,6 +219,46 @@ Error: %(text)s ) else: client.chat_postMessage(channel=channel, text=body) + + @backoff.on_exception(backoff.expo, SlackApiError, factor=10, base=2, max_tries=5) + @statsd_gauge("reports.slack.send") + def send(self) -> None: + global_logs_context = getattr(g, "logs_context", {}) or {} + try: + client = get_slack_client() + title = self._content.name + body = self._get_body() + + try: + channels = self._get_channels(client) + except SlackApiError: + logger.warning( + "Slack scope missing. Using deprecated API to get channels. Please update your Slack app to use the new API.", + extra={ + "execution_id": global_logs_context.get("execution_id"), + }, + ) + self._deprecated_upload_files(client, title, body) + return + + if channels == []: + raise NotificationParamException("No valid channel found") + + files = self._get_inline_files() + + # files_upload returns SlackResponse as we run it in sync mode. + for channel in channels: + if len(files) > 0: + for file in files: + client.files_upload_v2( + channel=channel, + file=file, + initial_comment=body, + title=title, + ) + else: + client.chat_postMessage(channel=channel, text=body) + logger.info( "Report sent to slack", extra={ diff --git a/superset/tasks/slack_util.py b/superset/tasks/slack_util.py deleted file mode 100644 index 5226f50dc..000000000 --- a/superset/tasks/slack_util.py +++ /dev/null @@ -1,62 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -DEPRECATION NOTICE: this module is deprecated and will be removed on 2.0. -""" - -import logging -from io import IOBase -from typing import cast, Optional, Union - -import backoff -from flask import current_app -from slack_sdk import WebClient -from slack_sdk.errors import SlackApiError -from slack_sdk.web.slack_response import SlackResponse - -# Globals -logger = logging.getLogger("tasks.slack_util") - - -@backoff.on_exception(backoff.expo, SlackApiError, factor=10, base=2, max_tries=5) -def deliver_slack_msg( - slack_channel: str, - subject: str, - body: str, - file: Optional[Union[str, IOBase, bytes]], -) -> None: - config = current_app.config - token = config["SLACK_API_TOKEN"] - if callable(token): - token = token() - client = WebClient(token=token, proxy=config["SLACK_PROXY"]) - # files_upload returns SlackResponse as we run it in sync mode. - if file: - response = cast( - SlackResponse, - client.files_upload_v2( - channels=slack_channel, file=file, initial_comment=body, title=subject - ), - ) - assert response["file"], str(response) # the uploaded file - else: - response = cast( - SlackResponse, - client.chat_postMessage(channel=slack_channel, text=body), - ) - assert response["message"]["text"], str(response) - logger.info("Sent the report to the slack %s", slack_channel) diff --git a/superset/utils/json.py b/superset/utils/json.py index e3a530525..0d7e31b9c 100644 --- a/superset/utils/json.py +++ b/superset/utils/json.py @@ -24,7 +24,7 @@ import numpy as np import pandas as pd import simplejson from flask_babel.speaklater import LazyString -from simplejson import JSONDecodeError # noqa: F401 # pylint: disable=unused-import +from simplejson import JSONDecodeError from superset.utils.dates import datetime_to_epoch, EPOCH diff --git a/tests/integration_tests/fixtures/world_bank_dashboard.py b/tests/integration_tests/fixtures/world_bank_dashboard.py index 34c718c20..6e2b40860 100644 --- a/tests/integration_tests/fixtures/world_bank_dashboard.py +++ b/tests/integration_tests/fixtures/world_bank_dashboard.py @@ -28,6 +28,7 @@ from superset.connectors.sqla.models import SqlaTable from superset.models.core import Database from superset.models.dashboard import Dashboard from superset.models.slice import Slice +from superset.reports.models import ReportSchedule from superset.utils import json from superset.utils.core import get_example_default_schema from superset.utils.database import get_example_database @@ -81,6 +82,7 @@ def load_world_bank_dashboard_with_slices_module_scope(load_world_bank_data): with app.app_context(): dash_id_to_delete, slices_ids_to_delete = create_dashboard_for_loaded_data() yield + _cleanup_reports(dash_id_to_delete, slices_ids_to_delete) _cleanup(dash_id_to_delete, slices_ids_to_delete) @@ -143,6 +145,21 @@ def _cleanup(dash_id: int, slices_ids: list[int]) -> None: db.session.commit() +def _cleanup_reports(dash_id: int, slices_ids: list[int]) -> None: + reports_with_dash = ( + db.session.query(ReportSchedule).filter_by(dashboard_id=dash_id).all() + ) + reports_with_slices = ( + db.session.query(ReportSchedule) + .filter(ReportSchedule.chart_id.in_(slices_ids)) + .all() + ) + + for report in reports_with_dash + reports_with_slices: + db.session.delete(report) + db.session.commit() + + def _get_dataframe(database: Database) -> DataFrame: data = _get_world_bank_data() df = pd.DataFrame.from_dict(data) diff --git a/tests/integration_tests/reports/commands_tests.py b/tests/integration_tests/reports/commands_tests.py index e57912759..e4c64a51b 100644 --- a/tests/integration_tests/reports/commands_tests.py +++ b/tests/integration_tests/reports/commands_tests.py @@ -676,7 +676,9 @@ def test_email_chart_report_schedule_alpha_owner( with freeze_time("2020-01-01T00:00:00Z"): AsyncExecuteReportScheduleCommand( - TEST_ID, create_report_email_chart_alpha_owner.id, datetime.utcnow() + TEST_ID, + create_report_email_chart_alpha_owner.id, + datetime.utcnow(), ).run() notification_targets = get_target_from_report_schedule( @@ -724,7 +726,9 @@ def test_email_chart_report_schedule_force_screenshot( with freeze_time("2020-01-01T00:00:00Z"): AsyncExecuteReportScheduleCommand( - TEST_ID, create_report_email_chart_force_screenshot.id, datetime.utcnow() + TEST_ID, + create_report_email_chart_force_screenshot.id, + datetime.utcnow(), ).run() notification_targets = get_target_from_report_schedule( @@ -1098,11 +1102,11 @@ def test_email_dashboard_report_schedule_force_screenshot( @pytest.mark.usefixtures( "load_birth_names_dashboard_with_slices", "create_report_slack_chart" ) -@patch("superset.utils.slack.WebClient.files_upload_v2") +@patch("superset.reports.notifications.slack.get_slack_client") @patch("superset.utils.screenshots.ChartScreenshot.get_screenshot") def test_slack_chart_report_schedule( screenshot_mock, - file_upload_mock, + slack_client_mock, create_report_slack_chart, ): """ @@ -1110,6 +1114,13 @@ def test_slack_chart_report_schedule( """ # setup screenshot mock screenshot_mock.return_value = SCREENSHOT_FILE + notification_targets = get_target_from_report_schedule(create_report_slack_chart) + + channel_name = notification_targets[0] + channel_id = "channel_id" + slack_client_mock.return_value.conversations_list.return_value = { + "channels": [{"id": channel_id, "name": channel_name}] + } with freeze_time("2020-01-01T00:00:00Z"): with patch.object(current_app.config["STATS_LOGGER"], "gauge") as statsd_mock: @@ -1117,12 +1128,57 @@ def test_slack_chart_report_schedule( TEST_ID, create_report_slack_chart.id, datetime.utcnow() ).run() - notification_targets = get_target_from_report_schedule( - create_report_slack_chart + assert ( + slack_client_mock.return_value.files_upload_v2.call_args[1]["channel"] + == channel_id + ) + assert ( + slack_client_mock.return_value.files_upload_v2.call_args[1]["file"] + == SCREENSHOT_FILE ) - assert file_upload_mock.call_args[1]["channels"] == notification_targets[0] - assert file_upload_mock.call_args[1]["file"] == SCREENSHOT_FILE + # Assert logs are correct + assert_log(ReportState.SUCCESS) + statsd_mock.assert_called_once_with("reports.slack.send.ok", 1) + + +@pytest.mark.usefixtures( + "load_birth_names_dashboard_with_slices", "create_report_slack_chart" +) +@patch("superset.reports.notifications.slack.get_slack_client") +@patch("superset.utils.screenshots.ChartScreenshot.get_screenshot") +def test_slack_chart_report_schedule_deprecated( + screenshot_mock, + slack_client_mock, + create_report_slack_chart, +): + """ + ExecuteReport Command: Test chart slack report schedule + """ + # setup screenshot mock + screenshot_mock.return_value = SCREENSHOT_FILE + notification_targets = get_target_from_report_schedule(create_report_slack_chart) + + channel_name = notification_targets[0] + + slack_client_mock.return_value.conversations_list.side_effect = SlackApiError( + "Error", "Response" + ) + + with freeze_time("2020-01-01T00:00:00Z"): + with patch.object(current_app.config["STATS_LOGGER"], "gauge") as statsd_mock: + AsyncExecuteReportScheduleCommand( + TEST_ID, create_report_slack_chart.id, datetime.utcnow() + ).run() + + assert ( + slack_client_mock.return_value.files_upload.call_args[1]["channels"] + == channel_name + ) + assert ( + slack_client_mock.return_value.files_upload.call_args[1]["file"] + == SCREENSHOT_FILE + ) # Assert logs are correct assert_log(ReportState.SUCCESS) @@ -1186,7 +1242,7 @@ def test_slack_chart_report_schedule_with_errors( @pytest.mark.usefixtures( "load_birth_names_dashboard_with_slices", "create_report_slack_chart_with_csv" ) -@patch("superset.utils.slack.WebClient.files_upload_v2") +@patch("superset.reports.notifications.slack.get_slack_client") @patch("superset.utils.csv.urllib.request.urlopen") @patch("superset.utils.csv.urllib.request.OpenerDirector.open") @patch("superset.utils.csv.get_chart_csv_data") @@ -1194,7 +1250,7 @@ def test_slack_chart_report_schedule_with_csv( csv_mock, mock_open, mock_urlopen, - file_upload_mock, + slack_client_mock_class, create_report_slack_chart_with_csv, ): """ @@ -1207,16 +1263,82 @@ def test_slack_chart_report_schedule_with_csv( mock_urlopen.return_value.getcode.return_value = 200 response.read.return_value = CSV_FILE + notification_targets = get_target_from_report_schedule( + create_report_slack_chart_with_csv + ) + + channel_name = notification_targets[0] + channel_id = "channel_id" + slack_client_mock_class.return_value = Mock() + slack_client_mock_class.return_value.conversations_list.return_value = { + "channels": [{"id": channel_id, "name": channel_name}] + } + with freeze_time("2020-01-01T00:00:00Z"): AsyncExecuteReportScheduleCommand( TEST_ID, create_report_slack_chart_with_csv.id, datetime.utcnow() ).run() - notification_targets = get_target_from_report_schedule( - create_report_slack_chart_with_csv + assert ( + slack_client_mock_class.return_value.files_upload_v2.call_args[1]["channel"] + == channel_id + ) + assert ( + slack_client_mock_class.return_value.files_upload_v2.call_args[1]["file"] + == CSV_FILE + ) + + # Assert logs are correct + assert_log(ReportState.SUCCESS) + + +@pytest.mark.usefixtures( + "load_birth_names_dashboard_with_slices", "create_report_slack_chart_with_csv" +) +@patch("superset.reports.notifications.slack.get_slack_client") +@patch("superset.utils.csv.urllib.request.urlopen") +@patch("superset.utils.csv.urllib.request.OpenerDirector.open") +@patch("superset.utils.csv.get_chart_csv_data") +def test_slack_chart_report_schedule_with_csv_deprecated_api( + csv_mock, + mock_open, + mock_urlopen, + slack_client_mock_class, + create_report_slack_chart_with_csv, +): + """ + ExecuteReport Command: Test chart slack report schedule with CSV + """ + # setup csv mock + response = Mock() + mock_open.return_value = response + mock_urlopen.return_value = response + mock_urlopen.return_value.getcode.return_value = 200 + response.read.return_value = CSV_FILE + + notification_targets = get_target_from_report_schedule( + create_report_slack_chart_with_csv + ) + + channel_name = notification_targets[0] + slack_client_mock_class.return_value = Mock() + slack_client_mock_class.return_value.conversations_list.side_effect = SlackApiError( + "Error", "Response" + ) + + with freeze_time("2020-01-01T00:00:00Z"): + AsyncExecuteReportScheduleCommand( + TEST_ID, create_report_slack_chart_with_csv.id, datetime.utcnow() + ).run() + + assert ( + slack_client_mock_class.return_value.files_upload.call_args[1]["channels"] + == channel_name + ) + assert ( + slack_client_mock_class.return_value.files_upload.call_args[1]["file"] + == CSV_FILE ) - assert file_upload_mock.call_args[1]["channels"] == notification_targets[0] - assert file_upload_mock.call_args[1]["file"] == CSV_FILE # Assert logs are correct assert_log(ReportState.SUCCESS) @@ -1225,15 +1347,15 @@ def test_slack_chart_report_schedule_with_csv( @pytest.mark.usefixtures( "load_birth_names_dashboard_with_slices", "create_report_slack_chart_with_text" ) -@patch("superset.utils.slack.WebClient.chat_postMessage") @patch("superset.utils.csv.urllib.request.urlopen") @patch("superset.utils.csv.urllib.request.OpenerDirector.open") +@patch("superset.reports.notifications.slack.get_slack_client") @patch("superset.utils.csv.get_chart_dataframe") def test_slack_chart_report_schedule_with_text( dataframe_mock, + slack_client_mock_class, mock_open, mock_urlopen, - post_message_mock, create_report_slack_chart_with_text, ): """ @@ -1255,11 +1377,23 @@ def test_slack_chart_report_schedule_with_text( }, "colnames": [("t1",), ("t2",), ("t3__sum",)], "indexnames": [(0,), (1,)], + "coltypes": [1, 1, 0], }, ], } ).encode("utf-8") + notification_targets = get_target_from_report_schedule( + create_report_slack_chart_with_text + ) + + channel_name = notification_targets[0] + channel_id = "channel_id" + + slack_client_mock_class.return_value.conversations_list.return_value = { + "channels": [{"id": channel_id, "name": channel_name}] + } + with freeze_time("2020-01-01T00:00:00Z"): AsyncExecuteReportScheduleCommand( TEST_ID, create_report_slack_chart_with_text.id, datetime.utcnow() @@ -1269,10 +1403,98 @@ def test_slack_chart_report_schedule_with_text( |---:|:-----|:-----|:----------| | 0 | c11 | c12 | c13 | | 1 | c21 | c22 | c23 |""" - assert table_markdown in post_message_mock.call_args[1]["text"] + assert ( + table_markdown + in slack_client_mock_class.return_value.chat_postMessage.call_args[1][ + "text" + ] + ) assert ( f"" - in post_message_mock.call_args[1]["text"] + in slack_client_mock_class.return_value.chat_postMessage.call_args[1][ + "text" + ] + ) + + # Assert logs are correct + assert_log(ReportState.SUCCESS) + + +@pytest.mark.usefixtures( + "load_birth_names_dashboard_with_slices", "create_report_slack_chart_with_text" +) +@patch("superset.utils.csv.urllib.request.urlopen") +@patch("superset.utils.csv.urllib.request.OpenerDirector.open") +@patch("superset.reports.notifications.slack.get_slack_client") +@patch("superset.utils.csv.get_chart_dataframe") +def test_slack_chart_report_schedule_with_text_deprecated_slack_api( + dataframe_mock, + slack_client_mock_class, + mock_open, + mock_urlopen, + create_report_slack_chart_with_text, +): + """ + ExecuteReport Command: Test chart slack report schedule with text + """ + # setup dataframe mock + response = Mock() + mock_open.return_value = response + mock_urlopen.return_value = response + mock_urlopen.return_value.getcode.return_value = 200 + response.read.return_value = json.dumps( + { + "result": [ + { + "data": { + "t1": {0: "c11", 1: "c21"}, + "t2": {0: "c12", 1: "c22"}, + "t3__sum": {0: "c13", 1: "c23"}, + }, + "colnames": [("t1",), ("t2",), ("t3__sum",)], + "indexnames": [(0,), (1,)], + "coltypes": [1, 1, 0], + }, + ], + } + ).encode("utf-8") + + notification_targets = get_target_from_report_schedule( + create_report_slack_chart_with_text + ) + + channel_name = notification_targets[0] + + slack_client_mock_class.return_value.conversations_list.side_effect = SlackApiError( + "Error", "Response" + ) + + with freeze_time("2020-01-01T00:00:00Z"): + AsyncExecuteReportScheduleCommand( + TEST_ID, create_report_slack_chart_with_text.id, datetime.utcnow() + ).run() + + table_markdown = """| | t1 | t2 | t3__sum | +|---:|:-----|:-----|:----------| +| 0 | c11 | c12 | c13 | +| 1 | c21 | c22 | c23 |""" + assert ( + table_markdown + in slack_client_mock_class.return_value.chat_postMessage.call_args[1][ + "text" + ] + ) + assert ( + f"" + in slack_client_mock_class.return_value.chat_postMessage.call_args[1][ + "text" + ] + ) + assert ( + slack_client_mock_class.return_value.chat_postMessage.call_args[1][ + "channel" + ] + == channel_name ) # Assert logs are correct @@ -1298,7 +1520,9 @@ def test_report_schedule_working(create_report_slack_chart_working): with freeze_time("2020-01-01T00:00:00Z"): with pytest.raises(ReportSchedulePreviousWorkingError): AsyncExecuteReportScheduleCommand( - TEST_ID, create_report_slack_chart_working.id, datetime.utcnow() + TEST_ID, + create_report_slack_chart_working.id, + datetime.utcnow(), ).run() assert_log( @@ -1319,7 +1543,9 @@ def test_report_schedule_working_timeout(create_report_slack_chart_working): with freeze_time(current_time): with pytest.raises(ReportScheduleWorkingTimeoutError): AsyncExecuteReportScheduleCommand( - TEST_ID, create_report_slack_chart_working.id, datetime.utcnow() + TEST_ID, + create_report_slack_chart_working.id, + datetime.utcnow(), ).run() # Only needed for MySQL, understand why @@ -1353,10 +1579,14 @@ def test_report_schedule_success_grace(create_alert_slack_chart_success): @pytest.mark.usefixtures("create_alert_slack_chart_grace") -@patch("superset.utils.slack.WebClient.files_upload_v2") +@patch("superset.utils.slack.WebClient.files_upload") @patch("superset.utils.screenshots.ChartScreenshot.get_screenshot") +@patch("superset.reports.notifications.slack.get_slack_client") def test_report_schedule_success_grace_end( - screenshot_mock, file_upload_mock, create_alert_slack_chart_grace + slack_client_mock_class, + screenshot_mock, + file_upload_mock, + create_alert_slack_chart_grace, ): """ ExecuteReport Command: Test report schedule on grace to noop @@ -1369,6 +1599,17 @@ def test_report_schedule_success_grace_end( seconds=create_alert_slack_chart_grace.grace_period + 1 ) + notification_targets = get_target_from_report_schedule( + create_alert_slack_chart_grace + ) + + channel_name = notification_targets[0] + channel_id = "channel_id" + + slack_client_mock_class.return_value.conversations_list.return_value = { + "channels": [{"id": channel_id, "name": channel_name}] + } + with freeze_time(current_time): AsyncExecuteReportScheduleCommand( TEST_ID, create_alert_slack_chart_grace.id, datetime.utcnow() @@ -1533,7 +1774,15 @@ def test_slack_token_callable_chart_report( """ ExecuteReport Command: Test chart slack alert (slack token callable) """ + notification_targets = get_target_from_report_schedule(create_report_slack_chart) + + channel_name = notification_targets[0] + channel_id = "channel_id" slack_client_mock_class.return_value = Mock() + slack_client_mock_class.return_value.conversations_list.return_value = { + "channels": [{"id": channel_id, "name": channel_name}] + } + app.config["SLACK_API_TOKEN"] = Mock(return_value="cool_code") # setup screenshot mock screenshot_mock.return_value = SCREENSHOT_FILE @@ -1542,7 +1791,7 @@ def test_slack_token_callable_chart_report( AsyncExecuteReportScheduleCommand( TEST_ID, create_report_slack_chart.id, datetime.utcnow() ).run() - app.config["SLACK_API_TOKEN"].assert_called_once() + app.config["SLACK_API_TOKEN"].assert_called() assert slack_client_mock_class.called_with(token="cool_code", proxy="") assert_log(ReportState.SUCCESS) @@ -1661,9 +1910,7 @@ def test_soft_timeout_csv( TEST_ID, create_report_email_chart_with_csv.id, datetime.utcnow() ).run() - get_target_from_report_schedule( # noqa: F841 - create_report_email_chart_with_csv - ) + get_target_from_report_schedule(create_report_email_chart_with_csv) # noqa: F841 # Assert the email smtp address, asserts a notification was sent with the error assert email_mock.call_args[0][0] == DEFAULT_OWNER_EMAIL @@ -1701,9 +1948,7 @@ def test_generate_no_csv( TEST_ID, create_report_email_chart_with_csv.id, datetime.utcnow() ).run() - get_target_from_report_schedule( # noqa: F841 - create_report_email_chart_with_csv - ) + get_target_from_report_schedule(create_report_email_chart_with_csv) # noqa: F841 # Assert the email smtp address, asserts a notification was sent with the error assert email_mock.call_args[0][0] == DEFAULT_OWNER_EMAIL @@ -1808,7 +2053,9 @@ def test_invalid_sql_alert(email_mock, create_invalid_sql_alert_email_chart): with freeze_time("2020-01-01T00:00:00Z"): with pytest.raises((AlertQueryError, AlertQueryInvalidTypeError)): AsyncExecuteReportScheduleCommand( - TEST_ID, create_invalid_sql_alert_email_chart.id, datetime.utcnow() + TEST_ID, + create_invalid_sql_alert_email_chart.id, + datetime.utcnow(), ).run() # Assert the email smtp address, asserts a notification was sent with the error @@ -1824,7 +2071,9 @@ def test_grace_period_error(email_mock, create_invalid_sql_alert_email_chart): with freeze_time("2020-01-01T00:00:00Z"): with pytest.raises((AlertQueryError, AlertQueryInvalidTypeError)): AsyncExecuteReportScheduleCommand( - TEST_ID, create_invalid_sql_alert_email_chart.id, datetime.utcnow() + TEST_ID, + create_invalid_sql_alert_email_chart.id, + datetime.utcnow(), ).run() # Only needed for MySQL, understand why @@ -1839,7 +2088,9 @@ def test_grace_period_error(email_mock, create_invalid_sql_alert_email_chart): with freeze_time("2020-01-01T00:30:00Z"): with pytest.raises((AlertQueryError, AlertQueryInvalidTypeError)): AsyncExecuteReportScheduleCommand( - TEST_ID, create_invalid_sql_alert_email_chart.id, datetime.utcnow() + TEST_ID, + create_invalid_sql_alert_email_chart.id, + datetime.utcnow(), ).run() db.session.commit() assert ( @@ -1850,7 +2101,9 @@ def test_grace_period_error(email_mock, create_invalid_sql_alert_email_chart): with freeze_time("2020-01-01T01:30:00Z"): with pytest.raises((AlertQueryError, AlertQueryInvalidTypeError)): AsyncExecuteReportScheduleCommand( - TEST_ID, create_invalid_sql_alert_email_chart.id, datetime.utcnow() + TEST_ID, + create_invalid_sql_alert_email_chart.id, + datetime.utcnow(), ).run() db.session.commit() assert ( @@ -1871,7 +2124,9 @@ def test_grace_period_error_flap( with freeze_time("2020-01-01T00:00:00Z"): with pytest.raises((AlertQueryError, AlertQueryInvalidTypeError)): AsyncExecuteReportScheduleCommand( - TEST_ID, create_invalid_sql_alert_email_chart.id, datetime.utcnow() + TEST_ID, + create_invalid_sql_alert_email_chart.id, + datetime.utcnow(), ).run() db.session.commit() # Assert we have 1 notification sent on the log @@ -1882,7 +2137,9 @@ def test_grace_period_error_flap( with freeze_time("2020-01-01T00:30:00Z"): with pytest.raises((AlertQueryError, AlertQueryInvalidTypeError)): AsyncExecuteReportScheduleCommand( - TEST_ID, create_invalid_sql_alert_email_chart.id, datetime.utcnow() + TEST_ID, + create_invalid_sql_alert_email_chart.id, + datetime.utcnow(), ).run() db.session.commit() assert ( @@ -1915,7 +2172,9 @@ def test_grace_period_error_flap( with freeze_time("2020-01-01T00:32:00Z"): with pytest.raises((AlertQueryError, AlertQueryInvalidTypeError)): AsyncExecuteReportScheduleCommand( - TEST_ID, create_invalid_sql_alert_email_chart.id, datetime.utcnow() + TEST_ID, + create_invalid_sql_alert_email_chart.id, + datetime.utcnow(), ).run() db.session.commit() assert ( diff --git a/tests/unit_tests/notifications/slack_tests.py b/tests/unit_tests/notifications/slack_tests.py index 4cadf198f..19cd690b9 100644 --- a/tests/unit_tests/notifications/slack_tests.py +++ b/tests/unit_tests/notifications/slack_tests.py @@ -22,7 +22,9 @@ import pandas as pd @patch("superset.reports.notifications.slack.g") @patch("superset.reports.notifications.slack.logger") +@patch("superset.reports.notifications.slack.get_slack_client") def test_send_slack( + slack_client_mock: MagicMock, logger_mock: MagicMock, flask_global_mock: MagicMock, ) -> None: @@ -31,10 +33,12 @@ def test_send_slack( from superset.reports.models import ReportRecipients, ReportRecipientType from superset.reports.notifications.base import NotificationContent from superset.reports.notifications.slack import SlackNotification - from superset.utils.slack import WebClient execution_id = uuid.uuid4() flask_global_mock.logs_context = {"execution_id": execution_id} + slack_client_mock.return_value.conversations_list.return_value = { + "channels": [{"name": "some_channel", "id": "123"}] + } content = NotificationContent( name="test alert", header_data={ @@ -54,22 +58,20 @@ def test_send_slack( ), description='

This is a test alert


', ) - with patch.object( - WebClient, "chat_postMessage", return_value=True - ) as chat_post_message_mock: - SlackNotification( - recipient=ReportRecipients( - type=ReportRecipientType.SLACK, - recipient_config_json='{"target": "some_channel"}', - ), - content=content, - ).send() - logger_mock.info.assert_called_with( - "Report sent to slack", extra={"execution_id": execution_id} - ) - chat_post_message_mock.assert_called_with( - channel="some_channel", - text="""*test alert* + + SlackNotification( + recipient=ReportRecipients( + type=ReportRecipientType.SLACK, + recipient_config_json='{"target": "some_channel"}', + ), + content=content, + ).send() + logger_mock.info.assert_called_with( + "Report sent to slack", extra={"execution_id": execution_id} + ) + slack_client_mock.return_value.chat_postMessage.assert_called_with( + channel="123", + text="""*test alert*

This is a test alert


@@ -83,4 +85,4 @@ def test_send_slack( | 2 | 3 | 6 | 333 | ``` """, - ) + ) diff --git a/tests/unit_tests/reports/notifications/slack_tests.py b/tests/unit_tests/reports/notifications/slack_tests.py index 0a5e9baa4..2f0586021 100644 --- a/tests/unit_tests/reports/notifications/slack_tests.py +++ b/tests/unit_tests/reports/notifications/slack_tests.py @@ -14,6 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from unittest.mock import Mock + import pandas as pd @@ -53,6 +55,15 @@ def test_get_channel_with_multi_recipients() -> None: content=content, ) - result = slack_notification._get_channel() + client = Mock() + client.conversations_list.return_value = { + "channels": [ + {"name": "some_channel", "id": "23SDKE"}, + {"name": "second_channel", "id": "WD3D8KE"}, + {"name": "third_channel", "id": "223DFKE"}, + ] + } - assert result == "some_channel,second_channel,third_channel" + result = slack_notification._get_channels(client) + + assert result == ["23SDKE", "WD3D8KE", "223DFKE"]