refactor: Fix lint on superset/utils/core.py (#10120)

* Fix lint on superset/utils/core.py

* black

* mypy

* Fix some missing renames
This commit is contained in:
Will Barrett 2020-06-26 08:49:12 -07:00 committed by GitHub
parent 410c5be2f8
commit df71fac1e0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 174 additions and 137 deletions

View File

@ -41,7 +41,7 @@ naming_convention = {
def find_constraint_name(upgrade=True):
cols = {"column_name"} if upgrade else {"datasource_name"}
return generic_find_constraint_name(
table="columns", columns=cols, referenced="datasources", db=db
table="columns", columns=cols, referenced="datasources", database=db
)

View File

@ -45,10 +45,10 @@ def upgrade():
table="slices",
columns={"druid_datasource_id"},
referenced="datasources",
db=db,
database=db,
)
slices_ibfk_2 = generic_find_constraint_name(
table="slices", columns={"table_id"}, referenced="tables", db=db
table="slices", columns={"table_id"}, referenced="tables", database=db
)
with op.batch_alter_table("slices") as batch_op:
@ -119,7 +119,7 @@ def downgrade():
table="columns",
columns={"datasource_name"},
referenced="datasources",
db=db,
database=db,
)
with op.batch_alter_table("columns") as batch_op:
batch_op.drop_constraint(fk_columns, type_="foreignkey")

View File

@ -14,7 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=C,R,W
"""Utility functions used across Superset"""
import decimal
import errno
@ -175,10 +174,12 @@ class _memoized:
"""Return the function's docstring."""
return self.func.__doc__ or ""
def __get__(self, obj: Any, objtype: Type[Any]) -> functools.partial: # type: ignore
def __get__(
self, obj: Any, objtype: Type[Any]
) -> functools.partial: # type: ignore
if not self.is_method:
self.is_method = True
"""Support instance methods."""
# Support instance methods.
return functools.partial(self.__call__, obj)
@ -187,12 +188,11 @@ def memoized(
) -> Callable[..., Any]:
if func:
return _memoized(func)
else:
def wrapper(f: Callable[..., Any]) -> Callable[..., Any]:
return _memoized(f, watch)
def wrapper(f: Callable[..., Any]) -> Callable[..., Any]:
return _memoized(f, watch)
return wrapper
return wrapper
def parse_js_uri_path_item(
@ -247,7 +247,7 @@ def list_minus(l: List[Any], minus: List[Any]) -> List[Any]:
return [o for o in l if o not in minus]
def parse_human_datetime(s: str) -> datetime:
def parse_human_datetime(human_readable: str) -> datetime:
"""
Returns ``datetime.datetime`` from human readable strings
@ -269,23 +269,30 @@ def parse_human_datetime(s: str) -> datetime:
True
"""
try:
dttm = parse(s)
except Exception:
dttm = parse(human_readable)
except Exception: # pylint: disable=broad-except
try:
cal = parsedatetime.Calendar()
parsed_dttm, parsed_flags = cal.parseDT(s)
parsed_dttm, parsed_flags = cal.parseDT(human_readable)
# when time is not extracted, we 'reset to midnight'
if parsed_flags & 2 == 0:
parsed_dttm = parsed_dttm.replace(hour=0, minute=0, second=0)
dttm = dttm_from_timetuple(parsed_dttm.utctimetuple())
except Exception as ex:
logger.exception(ex)
raise ValueError("Couldn't parse date string [{}]".format(s))
raise ValueError("Couldn't parse date string [{}]".format(human_readable))
return dttm
def dttm_from_timetuple(d: struct_time) -> datetime:
return datetime(d.tm_year, d.tm_mon, d.tm_mday, d.tm_hour, d.tm_min, d.tm_sec)
def dttm_from_timetuple(date_: struct_time) -> datetime:
return datetime(
date_.tm_year,
date_.tm_mon,
date_.tm_mday,
date_.tm_hour,
date_.tm_min,
date_.tm_sec,
)
def md5_hex(data: str) -> str:
@ -302,13 +309,13 @@ class DashboardEncoder(json.JSONEncoder):
try:
vals = {k: v for k, v in o.__dict__.items() if k != "_sa_instance_state"}
return {"__{}__".format(o.__class__.__name__): vals}
except Exception:
if type(o) == datetime:
except Exception: # pylint: disable=broad-except
if isinstance(o, datetime):
return {"__datetime__": o.replace(microsecond=0).isoformat()}
return json.JSONEncoder(sort_keys=True).default(o)
def parse_human_timedelta(s: Optional[str]) -> timedelta:
def parse_human_timedelta(human_readable: Optional[str]) -> timedelta:
"""
Returns ``datetime.datetime`` from natural language time deltas
@ -317,9 +324,16 @@ def parse_human_timedelta(s: Optional[str]) -> timedelta:
"""
cal = parsedatetime.Calendar()
dttm = dttm_from_timetuple(datetime.now().timetuple())
d = cal.parse(s or "", dttm)[0]
d = datetime(d.tm_year, d.tm_mon, d.tm_mday, d.tm_hour, d.tm_min, d.tm_sec)
return d - dttm
date_ = cal.parse(human_readable or "", dttm)[0]
date_ = datetime(
date_.tm_year,
date_.tm_mon,
date_.tm_mday,
date_.tm_hour,
date_.tm_min,
date_.tm_sec,
)
return date_ - dttm
def parse_past_timedelta(delta_str: str) -> timedelta:
@ -336,7 +350,7 @@ def parse_past_timedelta(delta_str: str) -> timedelta:
)
class JSONEncodedDict(TypeDecorator):
class JSONEncodedDict(TypeDecorator): # pylint: disable=abstract-method
"""Represents an immutable structure as a json-encoded string."""
impl = TEXT
@ -352,7 +366,7 @@ class JSONEncodedDict(TypeDecorator):
return json.loads(value) if value is not None else None
def format_timedelta(td: timedelta) -> str:
def format_timedelta(time_delta: timedelta) -> str:
"""
Ensures negative time deltas are easily interpreted by humans
@ -362,34 +376,36 @@ def format_timedelta(td: timedelta) -> str:
>>> format_timedelta(td)
'-1 day, 5:06:00'
"""
if td < timedelta(0):
return "-" + str(abs(td))
else:
# Change this to format positive time deltas the way you want
return str(td)
if time_delta < timedelta(0):
return "-" + str(abs(time_delta))
# Change this to format positive time deltas the way you want
return str(time_delta)
def base_json_conv(obj: Any) -> Any:
def base_json_conv( # pylint: disable=inconsistent-return-statements,too-many-return-statements
obj: Any,
) -> Any:
if isinstance(obj, memoryview):
obj = obj.tobytes()
if isinstance(obj, np.int64):
return int(obj)
elif isinstance(obj, np.bool_):
if isinstance(obj, np.bool_):
return bool(obj)
elif isinstance(obj, np.ndarray):
if isinstance(obj, np.ndarray):
return obj.tolist()
elif isinstance(obj, set):
if isinstance(obj, set):
return list(obj)
elif isinstance(obj, decimal.Decimal):
if isinstance(obj, decimal.Decimal):
return float(obj)
elif isinstance(obj, uuid.UUID):
if isinstance(obj, uuid.UUID):
return str(obj)
elif isinstance(obj, timedelta):
if isinstance(obj, timedelta):
return format_timedelta(obj)
elif isinstance(obj, bytes):
if isinstance(obj, bytes):
try:
return obj.decode("utf-8")
except Exception:
except Exception: # pylint: disable=broad-except
return "[bytes]"
@ -409,10 +425,8 @@ def json_iso_dttm_ser(obj: Any, pessimistic: bool = False) -> str:
else:
if pessimistic:
return "Unserializable [{}]".format(type(obj))
else:
raise TypeError(
"Unserializable object {} of type {}".format(obj, type(obj))
)
raise TypeError("Unserializable object {} of type {}".format(obj, type(obj)))
return obj
@ -464,7 +478,7 @@ def error_msg_from_exception(ex: Exception) -> str:
return msg or str(ex)
def markdown(s: str, markup_wrap: Optional[bool] = False) -> str:
def markdown(raw: str, markup_wrap: Optional[bool] = False) -> str:
safe_markdown_tags = [
"h1",
"h2",
@ -496,18 +510,18 @@ def markdown(s: str, markup_wrap: Optional[bool] = False) -> str:
"img": ["src", "alt", "title"],
"a": ["href", "alt", "title"],
}
s = md.markdown(
s or "",
safe = md.markdown(
raw or "",
extensions=[
"markdown.extensions.tables",
"markdown.extensions.fenced_code",
"markdown.extensions.codehilite",
],
)
s = bleach.clean(s, safe_markdown_tags, safe_markdown_attrs)
safe = bleach.clean(safe, safe_markdown_tags, safe_markdown_attrs)
if markup_wrap:
s = Markup(s)
return s
safe = Markup(safe)
return safe
def readfile(file_path: str) -> Optional[str]:
@ -517,19 +531,21 @@ def readfile(file_path: str) -> Optional[str]:
def generic_find_constraint_name(
table: str, columns: Set[str], referenced: str, db: SQLA
table: str, columns: Set[str], referenced: str, database: SQLA
) -> Optional[str]:
"""Utility to find a constraint name in alembic migrations"""
t = sa.Table(table, db.metadata, autoload=True, autoload_with=db.engine)
tbl = sa.Table(
table, database.metadata, autoload=True, autoload_with=database.engine
)
for fk in t.foreign_key_constraints:
for fk in tbl.foreign_key_constraints:
if fk.referred_table.name == referenced and set(fk.column_keys) == columns:
return fk.name
return None
def generic_find_fk_constraint_name(
def generic_find_fk_constraint_name( # pylint: disable=invalid-name
table: str, columns: Set[str], referenced: str, insp: Inspector
) -> Optional[str]:
"""Utility to find a foreign-key constraint name in alembic migrations"""
@ -543,7 +559,7 @@ def generic_find_fk_constraint_name(
return None
def generic_find_fk_constraint_names(
def generic_find_fk_constraint_names( # pylint: disable=invalid-name
table: str, columns: Set[str], referenced: str, insp: Inspector
) -> Set[str]:
"""Utility to find foreign-key constraint names in alembic migrations"""
@ -584,11 +600,11 @@ def validate_json(obj: Union[bytes, bytearray, str]) -> None:
try:
json.loads(obj)
except Exception as ex:
logger.error(f"JSON is not valid {ex}")
logger.error("JSON is not valid %s", str(ex))
raise SupersetException("JSON is not valid")
class timeout:
class timeout: # pylint: disable=invalid-name
"""
To be used in a ``with`` block and timeout its content.
"""
@ -597,7 +613,9 @@ class timeout:
self.seconds = seconds
self.error_message = error_message
def handle_timeout(self, signum: int, frame: Any) -> None:
def handle_timeout( # pylint: disable=unused-argument
self, signum: int, frame: Any
) -> None:
logger.error("Process timed out")
raise SupersetTimeoutException(self.error_message)
@ -609,7 +627,9 @@ class timeout:
logger.warning("timeout can't be used in the current context")
logger.exception(ex)
def __exit__(self, type: Any, value: Any, traceback: TracebackType) -> None:
def __exit__( # pylint: disable=redefined-outer-name,unused-variable,redefined-builtin
self, type: Any, value: Any, traceback: TracebackType
) -> None:
try:
signal.alarm(0)
except ValueError as ex:
@ -619,7 +639,9 @@ class timeout:
def pessimistic_connection_handling(some_engine: Engine) -> None:
@event.listens_for(some_engine, "engine_connect")
def ping_connection(connection: Connection, branch: bool) -> None:
def ping_connection( # pylint: disable=unused-variable
connection: Connection, branch: bool
) -> None:
if branch:
# 'branch' refers to a sub-connection of a connection,
# we don't want to bother pinging on these.
@ -654,7 +676,7 @@ def pessimistic_connection_handling(some_engine: Engine) -> None:
connection.should_close_with_result = save_should_close_with_result
class QueryStatus:
class QueryStatus: # pylint: disable=too-few-public-methods
"""Enum-type class for query statuses"""
STOPPED: str = "stopped"
@ -666,7 +688,7 @@ class QueryStatus:
TIMED_OUT: str = "timed_out"
def notify_user_about_perm_udate(
def notify_user_about_perm_udate( # pylint: disable=too-many-arguments
granter: User,
user: User,
role: Role,
@ -692,7 +714,7 @@ def notify_user_about_perm_udate(
)
def send_email_smtp(
def send_email_smtp( # pylint: disable=invalid-name,too-many-arguments,too-many-locals
to: str,
subject: str,
html_content: str,
@ -762,36 +784,36 @@ def send_email_smtp(
image.add_header("Content-Disposition", "inline")
msg.attach(image)
send_MIME_email(smtp_mail_from, recipients, msg, config, dryrun=dryrun)
send_mime_email(smtp_mail_from, recipients, msg, config, dryrun=dryrun)
def send_MIME_email(
def send_mime_email(
e_from: str,
e_to: List[str],
mime_msg: MIMEMultipart,
config: Dict[str, Any],
dryrun: bool = False,
) -> None:
SMTP_HOST = config["SMTP_HOST"]
SMTP_PORT = config["SMTP_PORT"]
SMTP_USER = config["SMTP_USER"]
SMTP_PASSWORD = config["SMTP_PASSWORD"]
SMTP_STARTTLS = config["SMTP_STARTTLS"]
SMTP_SSL = config["SMTP_SSL"]
smtp_host = config["SMTP_HOST"]
smtp_port = config["SMTP_PORT"]
smtp_user = config["SMTP_USER"]
smtp_password = config["SMTP_PASSWORD"]
smtp_starttls = config["SMTP_STARTTLS"]
smtp_ssl = config["SMTP_SSL"]
if not dryrun:
s = (
smtplib.SMTP_SSL(SMTP_HOST, SMTP_PORT)
if SMTP_SSL
else smtplib.SMTP(SMTP_HOST, SMTP_PORT)
smtp = (
smtplib.SMTP_SSL(smtp_host, smtp_port)
if smtp_ssl
else smtplib.SMTP(smtp_host, smtp_port)
)
if SMTP_STARTTLS:
s.starttls()
if SMTP_USER and SMTP_PASSWORD:
s.login(SMTP_USER, SMTP_PASSWORD)
logger.info("Sent an email to " + str(e_to))
s.sendmail(e_from, e_to, mime_msg.as_string())
s.quit()
if smtp_starttls:
smtp.starttls()
if smtp_user and smtp_password:
smtp.login(smtp_user, smtp_password)
logger.info("Sent an email to %s", str(e_to))
smtp.sendmail(e_from, e_to, mime_msg.as_string())
smtp.quit()
else:
logger.info("Dryrun enabled, email notification content is below:")
logger.info(mime_msg.as_string())
@ -800,7 +822,7 @@ def send_MIME_email(
def get_email_address_list(address_string: str) -> List[str]:
address_string_list: List[str] = []
if isinstance(address_string, str):
address_string_list = re.split(",|\s|;", address_string)
address_string_list = re.split(r",|\s|;", address_string)
return [x.strip() for x in address_string_list if x.strip()]
@ -837,16 +859,16 @@ def zlib_decompress(blob: bytes, decode: Optional[bool] = True) -> Union[bytes,
def to_adhoc(
filt: Dict[str, Any], expressionType: str = "SIMPLE", clause: str = "where"
filt: Dict[str, Any], expression_type: str = "SIMPLE", clause: str = "where"
) -> Dict[str, Any]:
result = {
"clause": clause.upper(),
"expressionType": expressionType,
"expressionType": expression_type,
"filterOptionName": str(uuid.uuid4()),
"isExtra": True if filt.get("isExtra") is True else False,
"isExtra": bool(filt.get("isExtra")),
}
if expressionType == "SIMPLE":
if expression_type == "SIMPLE":
result.update(
{
"comparator": filt.get("val"),
@ -854,13 +876,15 @@ def to_adhoc(
"subject": filt.get("col"),
}
)
elif expressionType == "SQL":
elif expression_type == "SQL":
result.update({"sqlExpression": filt.get(clause)})
return result
def merge_extra_filters(form_data: Dict[str, Any]) -> None:
def merge_extra_filters( # pylint: disable=too-many-branches
form_data: Dict[str, Any]
) -> None:
# extra_filters are temporary/contextual filters (using the legacy constructs)
# that are external to the slice definition. We use those for dynamic
# interactive filters like the ones emitted by the "Filter Box" visualization.
@ -886,8 +910,8 @@ def merge_extra_filters(form_data: Dict[str, Any]) -> None:
def get_filter_key(f: Dict[str, Any]) -> str:
if "expressionType" in f:
return "{}__{}".format(f["subject"], f["operator"])
else:
return "{}__{}".format(f["col"], f["op"])
return "{}__{}".format(f["col"], f["op"])
existing_filters = {}
for existing in form_data["adhoc_filters"]:
@ -898,7 +922,9 @@ def merge_extra_filters(form_data: Dict[str, Any]) -> None:
):
existing_filters[get_filter_key(existing)] = existing["comparator"]
for filtr in form_data["extra_filters"]:
for filtr in form_data[ # pylint: disable=too-many-nested-blocks
"extra_filters"
]:
filtr["isExtra"] = True
# Pull out time filters/options and merge into form data
if date_options.get(filtr["col"]):
@ -950,8 +976,8 @@ def user_label(user: User) -> Optional[str]:
if user:
if user.first_name and user.last_name:
return user.first_name + " " + user.last_name
else:
return user.username
return user.username
return None
@ -967,7 +993,7 @@ def get_or_create_db(
)
if not database:
logger.info(f"Creating database reference for {database_name}")
logger.info("Creating database reference for %s", database_name)
database = models.Database(database_name=database_name, *args, **kwargs)
db.session.add(database)
@ -1017,7 +1043,7 @@ def ensure_path_exists(path: str) -> None:
raise
def get_since_until(
def get_since_until( # pylint: disable=too-many-arguments
time_range: Optional[str] = None,
since: Optional[str] = None,
until: Optional[str] = None,
@ -1050,8 +1076,12 @@ def get_since_until(
"""
separator = " : "
relative_start = parse_human_datetime(relative_start if relative_start else "today") # type: ignore
relative_end = parse_human_datetime(relative_end if relative_end else "today") # type: ignore
relative_start = parse_human_datetime( # type: ignore
relative_start if relative_start else "today"
)
relative_end = parse_human_datetime( # type: ignore
relative_end if relative_end else "today"
)
common_time_frames = {
"Last day": (
relative_start - relativedelta(days=1), # type: ignore
@ -1132,33 +1162,37 @@ def add_ago_to_since(since: str) -> str:
return since
def convert_legacy_filters_into_adhoc(fd: FormData) -> None:
def convert_legacy_filters_into_adhoc( # pylint: disable=invalid-name
form_data: FormData,
) -> None:
mapping = {"having": "having_filters", "where": "filters"}
if not fd.get("adhoc_filters"):
fd["adhoc_filters"] = []
if not form_data.get("adhoc_filters"):
form_data["adhoc_filters"] = []
for clause, filters in mapping.items():
if clause in fd and fd[clause] != "":
fd["adhoc_filters"].append(to_adhoc(fd, "SQL", clause))
if clause in form_data and form_data[clause] != "":
form_data["adhoc_filters"].append(to_adhoc(form_data, "SQL", clause))
if filters in fd:
for filt in filter(lambda x: x is not None, fd[filters]):
fd["adhoc_filters"].append(to_adhoc(filt, "SIMPLE", clause))
if filters in form_data:
for filt in filter(lambda x: x is not None, form_data[filters]):
form_data["adhoc_filters"].append(to_adhoc(filt, "SIMPLE", clause))
for key in ("filters", "having", "having_filters", "where"):
if key in fd:
del fd[key]
if key in form_data:
del form_data[key]
def split_adhoc_filters_into_base_filters(fd: FormData) -> None:
def split_adhoc_filters_into_base_filters( # pylint: disable=invalid-name
form_data: FormData,
) -> None:
"""
Mutates form data to restructure the adhoc filters in the form of the four base
filters, `where`, `having`, `filters`, and `having_filters` which represent
free form where sql, free form having sql, structured where clauses and structured
having clauses.
"""
adhoc_filters = fd.get("adhoc_filters")
adhoc_filters = form_data.get("adhoc_filters")
if isinstance(adhoc_filters, list):
simple_where_filters = []
simple_having_filters = []
@ -1189,17 +1223,21 @@ def split_adhoc_filters_into_base_filters(fd: FormData) -> None:
sql_where_filters.append(adhoc_filter.get("sqlExpression"))
elif clause == "HAVING":
sql_having_filters.append(adhoc_filter.get("sqlExpression"))
fd["where"] = " AND ".join(["({})".format(sql) for sql in sql_where_filters])
fd["having"] = " AND ".join(["({})".format(sql) for sql in sql_having_filters])
fd["having_filters"] = simple_having_filters
fd["filters"] = simple_where_filters
form_data["where"] = " AND ".join(
["({})".format(sql) for sql in sql_where_filters]
)
form_data["having"] = " AND ".join(
["({})".format(sql) for sql in sql_having_filters]
)
form_data["having_filters"] = simple_having_filters
form_data["filters"] = simple_where_filters
def get_username() -> Optional[str]:
"""Get username if within the flask context, otherwise return noffin'"""
try:
return g.user.username
except Exception:
except Exception: # pylint: disable=broad-except
return None
@ -1260,7 +1298,7 @@ def time_function(
return (stop - start) * 1000.0, response
def MediumText() -> Variant:
def MediumText() -> Variant: # pylint:disable=invalid-name
return Text().with_variant(MEDIUMTEXT(), "mysql")
@ -1280,12 +1318,12 @@ def get_stacktrace() -> Optional[str]:
def split(
s: str, delimiter: str = " ", quote: str = '"', escaped_quote: str = r"\""
string: str, delimiter: str = " ", quote: str = '"', escaped_quote: str = r"\""
) -> Iterator[str]:
"""
A split function that is aware of quotes and parentheses.
:param s: string to split
:param string: string to split
:param delimiter: string defining where to split, usually a comma or space
:param quote: string, either a single or a double quote
:param escaped_quote: string representing an escaped quote
@ -1294,21 +1332,21 @@ def split(
parens = 0
quotes = False
i = 0
for j, c in enumerate(s):
for j, character in enumerate(string):
complete = parens == 0 and not quotes
if complete and c == delimiter:
yield s[i:j]
if complete and character == delimiter:
yield string[i:j]
i = j + len(delimiter)
elif c == "(":
elif character == "(":
parens += 1
elif c == ")":
elif character == ")":
parens -= 1
elif c == quote:
if quotes and s[j - len(escaped_quote) + 1 : j + 1] != escaped_quote:
elif character == quote:
if quotes and string[j - len(escaped_quote) + 1 : j + 1] != escaped_quote:
quotes = False
elif not quotes:
quotes = True
yield s[i:]
yield string[i:]
def get_iterable(x: Any) -> List[Any]:
@ -1382,7 +1420,7 @@ class FilterOperator(str, Enum):
LIKE = "LIKE"
IS_NULL = "IS NULL"
IS_NOT_NULL = "IS NOT NULL"
IN = "IN"
IN = "IN" # pylint: disable=invalid-name
NOT_IN = "NOT IN"
REGEX = "REGEX"

View File

@ -14,7 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=C,R,W
import logging
from typing import Any, Dict, List, Optional

View File

@ -357,7 +357,7 @@ class RequestAccessTests(SupersetTestCase):
session.commit()
@mock.patch("superset.utils.core.send_MIME_email")
@mock.patch("superset.utils.core.send_mime_email")
def test_approve(self, mock_send_mime):
if app.config["ENABLE_ACCESS_REQUEST"]:
session = db.session

View File

@ -38,7 +38,7 @@ class EmailSmtpTest(SupersetTestCase):
def setUp(self):
app.config["smtp_ssl"] = False
@mock.patch("superset.utils.core.send_MIME_email")
@mock.patch("superset.utils.core.send_mime_email")
def test_send_smtp(self, mock_send_mime):
attachment = tempfile.NamedTemporaryFile()
attachment.write(b"attachment")
@ -58,7 +58,7 @@ class EmailSmtpTest(SupersetTestCase):
mimeapp = MIMEApplication("attachment")
assert msg.get_payload()[-1].get_payload() == mimeapp.get_payload()
@mock.patch("superset.utils.core.send_MIME_email")
@mock.patch("superset.utils.core.send_mime_email")
def test_send_smtp_data(self, mock_send_mime):
utils.send_email_smtp(
"to", "subject", "content", app.config, data={"1.txt": b"data"}
@ -75,7 +75,7 @@ class EmailSmtpTest(SupersetTestCase):
mimeapp = MIMEApplication("data")
assert msg.get_payload()[-1].get_payload() == mimeapp.get_payload()
@mock.patch("superset.utils.core.send_MIME_email")
@mock.patch("superset.utils.core.send_mime_email")
def test_send_smtp_inline_images(self, mock_send_mime):
image = read_fixture("sample.png")
utils.send_email_smtp(
@ -93,7 +93,7 @@ class EmailSmtpTest(SupersetTestCase):
mimeapp = MIMEImage(image)
assert msg.get_payload()[-1].get_payload() == mimeapp.get_payload()
@mock.patch("superset.utils.core.send_MIME_email")
@mock.patch("superset.utils.core.send_mime_email")
def test_send_bcc_smtp(self, mock_send_mime):
attachment = tempfile.NamedTemporaryFile()
attachment.write(b"attachment")
@ -124,7 +124,7 @@ class EmailSmtpTest(SupersetTestCase):
mock_smtp.return_value = mock.Mock()
mock_smtp_ssl.return_value = mock.Mock()
msg = MIMEMultipart()
utils.send_MIME_email("from", "to", msg, app.config, dryrun=False)
utils.send_mime_email("from", "to", msg, app.config, dryrun=False)
mock_smtp.assert_called_with(app.config["SMTP_HOST"], app.config["SMTP_PORT"])
assert mock_smtp.return_value.starttls.called
mock_smtp.return_value.login.assert_called_with(
@ -141,7 +141,7 @@ class EmailSmtpTest(SupersetTestCase):
app.config["SMTP_SSL"] = True
mock_smtp.return_value = mock.Mock()
mock_smtp_ssl.return_value = mock.Mock()
utils.send_MIME_email("from", "to", MIMEMultipart(), app.config, dryrun=False)
utils.send_mime_email("from", "to", MIMEMultipart(), app.config, dryrun=False)
assert not mock_smtp.called
mock_smtp_ssl.assert_called_with(
app.config["SMTP_HOST"], app.config["SMTP_PORT"]
@ -154,7 +154,7 @@ class EmailSmtpTest(SupersetTestCase):
app.config["SMTP_PASSWORD"] = None
mock_smtp.return_value = mock.Mock()
mock_smtp_ssl.return_value = mock.Mock()
utils.send_MIME_email("from", "to", MIMEMultipart(), app.config, dryrun=False)
utils.send_mime_email("from", "to", MIMEMultipart(), app.config, dryrun=False)
assert not mock_smtp_ssl.called
mock_smtp.assert_called_with(app.config["SMTP_HOST"], app.config["SMTP_PORT"])
assert not mock_smtp.login.called
@ -162,7 +162,7 @@ class EmailSmtpTest(SupersetTestCase):
@mock.patch("smtplib.SMTP_SSL")
@mock.patch("smtplib.SMTP")
def test_send_mime_dryrun(self, mock_smtp, mock_smtp_ssl):
utils.send_MIME_email("from", "to", MIMEMultipart(), app.config, dryrun=True)
utils.send_mime_email("from", "to", MIMEMultipart(), app.config, dryrun=True)
assert not mock_smtp.called
assert not mock_smtp_ssl.called