diff --git a/superset/migrations/versions/1226819ee0e3_fix_wrong_constraint_on_table_columns.py b/superset/migrations/versions/1226819ee0e3_fix_wrong_constraint_on_table_columns.py index 3f8daa91e..9f09c8991 100644 --- a/superset/migrations/versions/1226819ee0e3_fix_wrong_constraint_on_table_columns.py +++ b/superset/migrations/versions/1226819ee0e3_fix_wrong_constraint_on_table_columns.py @@ -41,7 +41,7 @@ naming_convention = { def find_constraint_name(upgrade=True): cols = {"column_name"} if upgrade else {"datasource_name"} return generic_find_constraint_name( - table="columns", columns=cols, referenced="datasources", db=db + table="columns", columns=cols, referenced="datasources", database=db ) diff --git a/superset/migrations/versions/3b626e2a6783_sync_db_with_models.py b/superset/migrations/versions/3b626e2a6783_sync_db_with_models.py index b7e55974d..a652439cc 100644 --- a/superset/migrations/versions/3b626e2a6783_sync_db_with_models.py +++ b/superset/migrations/versions/3b626e2a6783_sync_db_with_models.py @@ -45,10 +45,10 @@ def upgrade(): table="slices", columns={"druid_datasource_id"}, referenced="datasources", - db=db, + database=db, ) slices_ibfk_2 = generic_find_constraint_name( - table="slices", columns={"table_id"}, referenced="tables", db=db + table="slices", columns={"table_id"}, referenced="tables", database=db ) with op.batch_alter_table("slices") as batch_op: @@ -119,7 +119,7 @@ def downgrade(): table="columns", columns={"datasource_name"}, referenced="datasources", - db=db, + database=db, ) with op.batch_alter_table("columns") as batch_op: batch_op.drop_constraint(fk_columns, type_="foreignkey") diff --git a/superset/utils/core.py b/superset/utils/core.py index c7fcd4653..c3e3b462a 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=C,R,W """Utility functions used across Superset""" import decimal import errno @@ -175,10 +174,12 @@ class _memoized: """Return the function's docstring.""" return self.func.__doc__ or "" - def __get__(self, obj: Any, objtype: Type[Any]) -> functools.partial: # type: ignore + def __get__( + self, obj: Any, objtype: Type[Any] + ) -> functools.partial: # type: ignore if not self.is_method: self.is_method = True - """Support instance methods.""" + # Support instance methods. return functools.partial(self.__call__, obj) @@ -187,12 +188,11 @@ def memoized( ) -> Callable[..., Any]: if func: return _memoized(func) - else: - def wrapper(f: Callable[..., Any]) -> Callable[..., Any]: - return _memoized(f, watch) + def wrapper(f: Callable[..., Any]) -> Callable[..., Any]: + return _memoized(f, watch) - return wrapper + return wrapper def parse_js_uri_path_item( @@ -247,7 +247,7 @@ def list_minus(l: List[Any], minus: List[Any]) -> List[Any]: return [o for o in l if o not in minus] -def parse_human_datetime(s: str) -> datetime: +def parse_human_datetime(human_readable: str) -> datetime: """ Returns ``datetime.datetime`` from human readable strings @@ -269,23 +269,30 @@ def parse_human_datetime(s: str) -> datetime: True """ try: - dttm = parse(s) - except Exception: + dttm = parse(human_readable) + except Exception: # pylint: disable=broad-except try: cal = parsedatetime.Calendar() - parsed_dttm, parsed_flags = cal.parseDT(s) + parsed_dttm, parsed_flags = cal.parseDT(human_readable) # when time is not extracted, we 'reset to midnight' if parsed_flags & 2 == 0: parsed_dttm = parsed_dttm.replace(hour=0, minute=0, second=0) dttm = dttm_from_timetuple(parsed_dttm.utctimetuple()) except Exception as ex: logger.exception(ex) - raise ValueError("Couldn't parse date string [{}]".format(s)) + raise ValueError("Couldn't parse date string [{}]".format(human_readable)) return dttm -def dttm_from_timetuple(d: struct_time) -> datetime: - return datetime(d.tm_year, d.tm_mon, d.tm_mday, d.tm_hour, d.tm_min, d.tm_sec) +def dttm_from_timetuple(date_: struct_time) -> datetime: + return datetime( + date_.tm_year, + date_.tm_mon, + date_.tm_mday, + date_.tm_hour, + date_.tm_min, + date_.tm_sec, + ) def md5_hex(data: str) -> str: @@ -302,13 +309,13 @@ class DashboardEncoder(json.JSONEncoder): try: vals = {k: v for k, v in o.__dict__.items() if k != "_sa_instance_state"} return {"__{}__".format(o.__class__.__name__): vals} - except Exception: - if type(o) == datetime: + except Exception: # pylint: disable=broad-except + if isinstance(o, datetime): return {"__datetime__": o.replace(microsecond=0).isoformat()} return json.JSONEncoder(sort_keys=True).default(o) -def parse_human_timedelta(s: Optional[str]) -> timedelta: +def parse_human_timedelta(human_readable: Optional[str]) -> timedelta: """ Returns ``datetime.datetime`` from natural language time deltas @@ -317,9 +324,16 @@ def parse_human_timedelta(s: Optional[str]) -> timedelta: """ cal = parsedatetime.Calendar() dttm = dttm_from_timetuple(datetime.now().timetuple()) - d = cal.parse(s or "", dttm)[0] - d = datetime(d.tm_year, d.tm_mon, d.tm_mday, d.tm_hour, d.tm_min, d.tm_sec) - return d - dttm + date_ = cal.parse(human_readable or "", dttm)[0] + date_ = datetime( + date_.tm_year, + date_.tm_mon, + date_.tm_mday, + date_.tm_hour, + date_.tm_min, + date_.tm_sec, + ) + return date_ - dttm def parse_past_timedelta(delta_str: str) -> timedelta: @@ -336,7 +350,7 @@ def parse_past_timedelta(delta_str: str) -> timedelta: ) -class JSONEncodedDict(TypeDecorator): +class JSONEncodedDict(TypeDecorator): # pylint: disable=abstract-method """Represents an immutable structure as a json-encoded string.""" impl = TEXT @@ -352,7 +366,7 @@ class JSONEncodedDict(TypeDecorator): return json.loads(value) if value is not None else None -def format_timedelta(td: timedelta) -> str: +def format_timedelta(time_delta: timedelta) -> str: """ Ensures negative time deltas are easily interpreted by humans @@ -362,34 +376,36 @@ def format_timedelta(td: timedelta) -> str: >>> format_timedelta(td) '-1 day, 5:06:00' """ - if td < timedelta(0): - return "-" + str(abs(td)) - else: - # Change this to format positive time deltas the way you want - return str(td) + if time_delta < timedelta(0): + return "-" + str(abs(time_delta)) + + # Change this to format positive time deltas the way you want + return str(time_delta) -def base_json_conv(obj: Any) -> Any: +def base_json_conv( # pylint: disable=inconsistent-return-statements,too-many-return-statements + obj: Any, +) -> Any: if isinstance(obj, memoryview): obj = obj.tobytes() if isinstance(obj, np.int64): return int(obj) - elif isinstance(obj, np.bool_): + if isinstance(obj, np.bool_): return bool(obj) - elif isinstance(obj, np.ndarray): + if isinstance(obj, np.ndarray): return obj.tolist() - elif isinstance(obj, set): + if isinstance(obj, set): return list(obj) - elif isinstance(obj, decimal.Decimal): + if isinstance(obj, decimal.Decimal): return float(obj) - elif isinstance(obj, uuid.UUID): + if isinstance(obj, uuid.UUID): return str(obj) - elif isinstance(obj, timedelta): + if isinstance(obj, timedelta): return format_timedelta(obj) - elif isinstance(obj, bytes): + if isinstance(obj, bytes): try: return obj.decode("utf-8") - except Exception: + except Exception: # pylint: disable=broad-except return "[bytes]" @@ -409,10 +425,8 @@ def json_iso_dttm_ser(obj: Any, pessimistic: bool = False) -> str: else: if pessimistic: return "Unserializable [{}]".format(type(obj)) - else: - raise TypeError( - "Unserializable object {} of type {}".format(obj, type(obj)) - ) + + raise TypeError("Unserializable object {} of type {}".format(obj, type(obj))) return obj @@ -464,7 +478,7 @@ def error_msg_from_exception(ex: Exception) -> str: return msg or str(ex) -def markdown(s: str, markup_wrap: Optional[bool] = False) -> str: +def markdown(raw: str, markup_wrap: Optional[bool] = False) -> str: safe_markdown_tags = [ "h1", "h2", @@ -496,18 +510,18 @@ def markdown(s: str, markup_wrap: Optional[bool] = False) -> str: "img": ["src", "alt", "title"], "a": ["href", "alt", "title"], } - s = md.markdown( - s or "", + safe = md.markdown( + raw or "", extensions=[ "markdown.extensions.tables", "markdown.extensions.fenced_code", "markdown.extensions.codehilite", ], ) - s = bleach.clean(s, safe_markdown_tags, safe_markdown_attrs) + safe = bleach.clean(safe, safe_markdown_tags, safe_markdown_attrs) if markup_wrap: - s = Markup(s) - return s + safe = Markup(safe) + return safe def readfile(file_path: str) -> Optional[str]: @@ -517,19 +531,21 @@ def readfile(file_path: str) -> Optional[str]: def generic_find_constraint_name( - table: str, columns: Set[str], referenced: str, db: SQLA + table: str, columns: Set[str], referenced: str, database: SQLA ) -> Optional[str]: """Utility to find a constraint name in alembic migrations""" - t = sa.Table(table, db.metadata, autoload=True, autoload_with=db.engine) + tbl = sa.Table( + table, database.metadata, autoload=True, autoload_with=database.engine + ) - for fk in t.foreign_key_constraints: + for fk in tbl.foreign_key_constraints: if fk.referred_table.name == referenced and set(fk.column_keys) == columns: return fk.name return None -def generic_find_fk_constraint_name( +def generic_find_fk_constraint_name( # pylint: disable=invalid-name table: str, columns: Set[str], referenced: str, insp: Inspector ) -> Optional[str]: """Utility to find a foreign-key constraint name in alembic migrations""" @@ -543,7 +559,7 @@ def generic_find_fk_constraint_name( return None -def generic_find_fk_constraint_names( +def generic_find_fk_constraint_names( # pylint: disable=invalid-name table: str, columns: Set[str], referenced: str, insp: Inspector ) -> Set[str]: """Utility to find foreign-key constraint names in alembic migrations""" @@ -584,11 +600,11 @@ def validate_json(obj: Union[bytes, bytearray, str]) -> None: try: json.loads(obj) except Exception as ex: - logger.error(f"JSON is not valid {ex}") + logger.error("JSON is not valid %s", str(ex)) raise SupersetException("JSON is not valid") -class timeout: +class timeout: # pylint: disable=invalid-name """ To be used in a ``with`` block and timeout its content. """ @@ -597,7 +613,9 @@ class timeout: self.seconds = seconds self.error_message = error_message - def handle_timeout(self, signum: int, frame: Any) -> None: + def handle_timeout( # pylint: disable=unused-argument + self, signum: int, frame: Any + ) -> None: logger.error("Process timed out") raise SupersetTimeoutException(self.error_message) @@ -609,7 +627,9 @@ class timeout: logger.warning("timeout can't be used in the current context") logger.exception(ex) - def __exit__(self, type: Any, value: Any, traceback: TracebackType) -> None: + def __exit__( # pylint: disable=redefined-outer-name,unused-variable,redefined-builtin + self, type: Any, value: Any, traceback: TracebackType + ) -> None: try: signal.alarm(0) except ValueError as ex: @@ -619,7 +639,9 @@ class timeout: def pessimistic_connection_handling(some_engine: Engine) -> None: @event.listens_for(some_engine, "engine_connect") - def ping_connection(connection: Connection, branch: bool) -> None: + def ping_connection( # pylint: disable=unused-variable + connection: Connection, branch: bool + ) -> None: if branch: # 'branch' refers to a sub-connection of a connection, # we don't want to bother pinging on these. @@ -654,7 +676,7 @@ def pessimistic_connection_handling(some_engine: Engine) -> None: connection.should_close_with_result = save_should_close_with_result -class QueryStatus: +class QueryStatus: # pylint: disable=too-few-public-methods """Enum-type class for query statuses""" STOPPED: str = "stopped" @@ -666,7 +688,7 @@ class QueryStatus: TIMED_OUT: str = "timed_out" -def notify_user_about_perm_udate( +def notify_user_about_perm_udate( # pylint: disable=too-many-arguments granter: User, user: User, role: Role, @@ -692,7 +714,7 @@ def notify_user_about_perm_udate( ) -def send_email_smtp( +def send_email_smtp( # pylint: disable=invalid-name,too-many-arguments,too-many-locals to: str, subject: str, html_content: str, @@ -762,36 +784,36 @@ def send_email_smtp( image.add_header("Content-Disposition", "inline") msg.attach(image) - send_MIME_email(smtp_mail_from, recipients, msg, config, dryrun=dryrun) + send_mime_email(smtp_mail_from, recipients, msg, config, dryrun=dryrun) -def send_MIME_email( +def send_mime_email( e_from: str, e_to: List[str], mime_msg: MIMEMultipart, config: Dict[str, Any], dryrun: bool = False, ) -> None: - SMTP_HOST = config["SMTP_HOST"] - SMTP_PORT = config["SMTP_PORT"] - SMTP_USER = config["SMTP_USER"] - SMTP_PASSWORD = config["SMTP_PASSWORD"] - SMTP_STARTTLS = config["SMTP_STARTTLS"] - SMTP_SSL = config["SMTP_SSL"] + smtp_host = config["SMTP_HOST"] + smtp_port = config["SMTP_PORT"] + smtp_user = config["SMTP_USER"] + smtp_password = config["SMTP_PASSWORD"] + smtp_starttls = config["SMTP_STARTTLS"] + smtp_ssl = config["SMTP_SSL"] if not dryrun: - s = ( - smtplib.SMTP_SSL(SMTP_HOST, SMTP_PORT) - if SMTP_SSL - else smtplib.SMTP(SMTP_HOST, SMTP_PORT) + smtp = ( + smtplib.SMTP_SSL(smtp_host, smtp_port) + if smtp_ssl + else smtplib.SMTP(smtp_host, smtp_port) ) - if SMTP_STARTTLS: - s.starttls() - if SMTP_USER and SMTP_PASSWORD: - s.login(SMTP_USER, SMTP_PASSWORD) - logger.info("Sent an email to " + str(e_to)) - s.sendmail(e_from, e_to, mime_msg.as_string()) - s.quit() + if smtp_starttls: + smtp.starttls() + if smtp_user and smtp_password: + smtp.login(smtp_user, smtp_password) + logger.info("Sent an email to %s", str(e_to)) + smtp.sendmail(e_from, e_to, mime_msg.as_string()) + smtp.quit() else: logger.info("Dryrun enabled, email notification content is below:") logger.info(mime_msg.as_string()) @@ -800,7 +822,7 @@ def send_MIME_email( def get_email_address_list(address_string: str) -> List[str]: address_string_list: List[str] = [] if isinstance(address_string, str): - address_string_list = re.split(",|\s|;", address_string) + address_string_list = re.split(r",|\s|;", address_string) return [x.strip() for x in address_string_list if x.strip()] @@ -837,16 +859,16 @@ def zlib_decompress(blob: bytes, decode: Optional[bool] = True) -> Union[bytes, def to_adhoc( - filt: Dict[str, Any], expressionType: str = "SIMPLE", clause: str = "where" + filt: Dict[str, Any], expression_type: str = "SIMPLE", clause: str = "where" ) -> Dict[str, Any]: result = { "clause": clause.upper(), - "expressionType": expressionType, + "expressionType": expression_type, "filterOptionName": str(uuid.uuid4()), - "isExtra": True if filt.get("isExtra") is True else False, + "isExtra": bool(filt.get("isExtra")), } - if expressionType == "SIMPLE": + if expression_type == "SIMPLE": result.update( { "comparator": filt.get("val"), @@ -854,13 +876,15 @@ def to_adhoc( "subject": filt.get("col"), } ) - elif expressionType == "SQL": + elif expression_type == "SQL": result.update({"sqlExpression": filt.get(clause)}) return result -def merge_extra_filters(form_data: Dict[str, Any]) -> None: +def merge_extra_filters( # pylint: disable=too-many-branches + form_data: Dict[str, Any] +) -> None: # extra_filters are temporary/contextual filters (using the legacy constructs) # that are external to the slice definition. We use those for dynamic # interactive filters like the ones emitted by the "Filter Box" visualization. @@ -886,8 +910,8 @@ def merge_extra_filters(form_data: Dict[str, Any]) -> None: def get_filter_key(f: Dict[str, Any]) -> str: if "expressionType" in f: return "{}__{}".format(f["subject"], f["operator"]) - else: - return "{}__{}".format(f["col"], f["op"]) + + return "{}__{}".format(f["col"], f["op"]) existing_filters = {} for existing in form_data["adhoc_filters"]: @@ -898,7 +922,9 @@ def merge_extra_filters(form_data: Dict[str, Any]) -> None: ): existing_filters[get_filter_key(existing)] = existing["comparator"] - for filtr in form_data["extra_filters"]: + for filtr in form_data[ # pylint: disable=too-many-nested-blocks + "extra_filters" + ]: filtr["isExtra"] = True # Pull out time filters/options and merge into form data if date_options.get(filtr["col"]): @@ -950,8 +976,8 @@ def user_label(user: User) -> Optional[str]: if user: if user.first_name and user.last_name: return user.first_name + " " + user.last_name - else: - return user.username + + return user.username return None @@ -967,7 +993,7 @@ def get_or_create_db( ) if not database: - logger.info(f"Creating database reference for {database_name}") + logger.info("Creating database reference for %s", database_name) database = models.Database(database_name=database_name, *args, **kwargs) db.session.add(database) @@ -1017,7 +1043,7 @@ def ensure_path_exists(path: str) -> None: raise -def get_since_until( +def get_since_until( # pylint: disable=too-many-arguments time_range: Optional[str] = None, since: Optional[str] = None, until: Optional[str] = None, @@ -1050,8 +1076,12 @@ def get_since_until( """ separator = " : " - relative_start = parse_human_datetime(relative_start if relative_start else "today") # type: ignore - relative_end = parse_human_datetime(relative_end if relative_end else "today") # type: ignore + relative_start = parse_human_datetime( # type: ignore + relative_start if relative_start else "today" + ) + relative_end = parse_human_datetime( # type: ignore + relative_end if relative_end else "today" + ) common_time_frames = { "Last day": ( relative_start - relativedelta(days=1), # type: ignore @@ -1132,33 +1162,37 @@ def add_ago_to_since(since: str) -> str: return since -def convert_legacy_filters_into_adhoc(fd: FormData) -> None: +def convert_legacy_filters_into_adhoc( # pylint: disable=invalid-name + form_data: FormData, +) -> None: mapping = {"having": "having_filters", "where": "filters"} - if not fd.get("adhoc_filters"): - fd["adhoc_filters"] = [] + if not form_data.get("adhoc_filters"): + form_data["adhoc_filters"] = [] for clause, filters in mapping.items(): - if clause in fd and fd[clause] != "": - fd["adhoc_filters"].append(to_adhoc(fd, "SQL", clause)) + if clause in form_data and form_data[clause] != "": + form_data["adhoc_filters"].append(to_adhoc(form_data, "SQL", clause)) - if filters in fd: - for filt in filter(lambda x: x is not None, fd[filters]): - fd["adhoc_filters"].append(to_adhoc(filt, "SIMPLE", clause)) + if filters in form_data: + for filt in filter(lambda x: x is not None, form_data[filters]): + form_data["adhoc_filters"].append(to_adhoc(filt, "SIMPLE", clause)) for key in ("filters", "having", "having_filters", "where"): - if key in fd: - del fd[key] + if key in form_data: + del form_data[key] -def split_adhoc_filters_into_base_filters(fd: FormData) -> None: +def split_adhoc_filters_into_base_filters( # pylint: disable=invalid-name + form_data: FormData, +) -> None: """ Mutates form data to restructure the adhoc filters in the form of the four base filters, `where`, `having`, `filters`, and `having_filters` which represent free form where sql, free form having sql, structured where clauses and structured having clauses. """ - adhoc_filters = fd.get("adhoc_filters") + adhoc_filters = form_data.get("adhoc_filters") if isinstance(adhoc_filters, list): simple_where_filters = [] simple_having_filters = [] @@ -1189,17 +1223,21 @@ def split_adhoc_filters_into_base_filters(fd: FormData) -> None: sql_where_filters.append(adhoc_filter.get("sqlExpression")) elif clause == "HAVING": sql_having_filters.append(adhoc_filter.get("sqlExpression")) - fd["where"] = " AND ".join(["({})".format(sql) for sql in sql_where_filters]) - fd["having"] = " AND ".join(["({})".format(sql) for sql in sql_having_filters]) - fd["having_filters"] = simple_having_filters - fd["filters"] = simple_where_filters + form_data["where"] = " AND ".join( + ["({})".format(sql) for sql in sql_where_filters] + ) + form_data["having"] = " AND ".join( + ["({})".format(sql) for sql in sql_having_filters] + ) + form_data["having_filters"] = simple_having_filters + form_data["filters"] = simple_where_filters def get_username() -> Optional[str]: """Get username if within the flask context, otherwise return noffin'""" try: return g.user.username - except Exception: + except Exception: # pylint: disable=broad-except return None @@ -1260,7 +1298,7 @@ def time_function( return (stop - start) * 1000.0, response -def MediumText() -> Variant: +def MediumText() -> Variant: # pylint:disable=invalid-name return Text().with_variant(MEDIUMTEXT(), "mysql") @@ -1280,12 +1318,12 @@ def get_stacktrace() -> Optional[str]: def split( - s: str, delimiter: str = " ", quote: str = '"', escaped_quote: str = r"\"" + string: str, delimiter: str = " ", quote: str = '"', escaped_quote: str = r"\"" ) -> Iterator[str]: """ A split function that is aware of quotes and parentheses. - :param s: string to split + :param string: string to split :param delimiter: string defining where to split, usually a comma or space :param quote: string, either a single or a double quote :param escaped_quote: string representing an escaped quote @@ -1294,21 +1332,21 @@ def split( parens = 0 quotes = False i = 0 - for j, c in enumerate(s): + for j, character in enumerate(string): complete = parens == 0 and not quotes - if complete and c == delimiter: - yield s[i:j] + if complete and character == delimiter: + yield string[i:j] i = j + len(delimiter) - elif c == "(": + elif character == "(": parens += 1 - elif c == ")": + elif character == ")": parens -= 1 - elif c == quote: - if quotes and s[j - len(escaped_quote) + 1 : j + 1] != escaped_quote: + elif character == quote: + if quotes and string[j - len(escaped_quote) + 1 : j + 1] != escaped_quote: quotes = False elif not quotes: quotes = True - yield s[i:] + yield string[i:] def get_iterable(x: Any) -> List[Any]: @@ -1382,7 +1420,7 @@ class FilterOperator(str, Enum): LIKE = "LIKE" IS_NULL = "IS NULL" IS_NOT_NULL = "IS NOT NULL" - IN = "IN" + IN = "IN" # pylint: disable=invalid-name NOT_IN = "NOT IN" REGEX = "REGEX" diff --git a/superset/utils/dict_import_export.py b/superset/utils/dict_import_export.py index a58635d22..4d9e0496b 100644 --- a/superset/utils/dict_import_export.py +++ b/superset/utils/dict_import_export.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=C,R,W import logging from typing import Any, Dict, List, Optional diff --git a/tests/access_tests.py b/tests/access_tests.py index f5ac8e09e..0af741529 100644 --- a/tests/access_tests.py +++ b/tests/access_tests.py @@ -357,7 +357,7 @@ class RequestAccessTests(SupersetTestCase): session.commit() - @mock.patch("superset.utils.core.send_MIME_email") + @mock.patch("superset.utils.core.send_mime_email") def test_approve(self, mock_send_mime): if app.config["ENABLE_ACCESS_REQUEST"]: session = db.session diff --git a/tests/email_tests.py b/tests/email_tests.py index dfbbda0dd..59af211fb 100644 --- a/tests/email_tests.py +++ b/tests/email_tests.py @@ -38,7 +38,7 @@ class EmailSmtpTest(SupersetTestCase): def setUp(self): app.config["smtp_ssl"] = False - @mock.patch("superset.utils.core.send_MIME_email") + @mock.patch("superset.utils.core.send_mime_email") def test_send_smtp(self, mock_send_mime): attachment = tempfile.NamedTemporaryFile() attachment.write(b"attachment") @@ -58,7 +58,7 @@ class EmailSmtpTest(SupersetTestCase): mimeapp = MIMEApplication("attachment") assert msg.get_payload()[-1].get_payload() == mimeapp.get_payload() - @mock.patch("superset.utils.core.send_MIME_email") + @mock.patch("superset.utils.core.send_mime_email") def test_send_smtp_data(self, mock_send_mime): utils.send_email_smtp( "to", "subject", "content", app.config, data={"1.txt": b"data"} @@ -75,7 +75,7 @@ class EmailSmtpTest(SupersetTestCase): mimeapp = MIMEApplication("data") assert msg.get_payload()[-1].get_payload() == mimeapp.get_payload() - @mock.patch("superset.utils.core.send_MIME_email") + @mock.patch("superset.utils.core.send_mime_email") def test_send_smtp_inline_images(self, mock_send_mime): image = read_fixture("sample.png") utils.send_email_smtp( @@ -93,7 +93,7 @@ class EmailSmtpTest(SupersetTestCase): mimeapp = MIMEImage(image) assert msg.get_payload()[-1].get_payload() == mimeapp.get_payload() - @mock.patch("superset.utils.core.send_MIME_email") + @mock.patch("superset.utils.core.send_mime_email") def test_send_bcc_smtp(self, mock_send_mime): attachment = tempfile.NamedTemporaryFile() attachment.write(b"attachment") @@ -124,7 +124,7 @@ class EmailSmtpTest(SupersetTestCase): mock_smtp.return_value = mock.Mock() mock_smtp_ssl.return_value = mock.Mock() msg = MIMEMultipart() - utils.send_MIME_email("from", "to", msg, app.config, dryrun=False) + utils.send_mime_email("from", "to", msg, app.config, dryrun=False) mock_smtp.assert_called_with(app.config["SMTP_HOST"], app.config["SMTP_PORT"]) assert mock_smtp.return_value.starttls.called mock_smtp.return_value.login.assert_called_with( @@ -141,7 +141,7 @@ class EmailSmtpTest(SupersetTestCase): app.config["SMTP_SSL"] = True mock_smtp.return_value = mock.Mock() mock_smtp_ssl.return_value = mock.Mock() - utils.send_MIME_email("from", "to", MIMEMultipart(), app.config, dryrun=False) + utils.send_mime_email("from", "to", MIMEMultipart(), app.config, dryrun=False) assert not mock_smtp.called mock_smtp_ssl.assert_called_with( app.config["SMTP_HOST"], app.config["SMTP_PORT"] @@ -154,7 +154,7 @@ class EmailSmtpTest(SupersetTestCase): app.config["SMTP_PASSWORD"] = None mock_smtp.return_value = mock.Mock() mock_smtp_ssl.return_value = mock.Mock() - utils.send_MIME_email("from", "to", MIMEMultipart(), app.config, dryrun=False) + utils.send_mime_email("from", "to", MIMEMultipart(), app.config, dryrun=False) assert not mock_smtp_ssl.called mock_smtp.assert_called_with(app.config["SMTP_HOST"], app.config["SMTP_PORT"]) assert not mock_smtp.login.called @@ -162,7 +162,7 @@ class EmailSmtpTest(SupersetTestCase): @mock.patch("smtplib.SMTP_SSL") @mock.patch("smtplib.SMTP") def test_send_mime_dryrun(self, mock_smtp, mock_smtp_ssl): - utils.send_MIME_email("from", "to", MIMEMultipart(), app.config, dryrun=True) + utils.send_mime_email("from", "to", MIMEMultipart(), app.config, dryrun=True) assert not mock_smtp.called assert not mock_smtp_ssl.called