From 1c656feb95c15007c3c5f90b199b721c26bfa0bd Mon Sep 17 00:00:00 2001 From: John Bodley <4567245+john-bodley@users.noreply.github.com> Date: Fri, 24 Apr 2020 10:07:35 -0700 Subject: [PATCH] [mypy] Enforcing typing for a number of modules (#9586) Co-authored-by: John Bodley --- setup.cfg | 2 +- superset/connectors/druid/models.py | 2 + superset/exceptions.py | 2 +- superset/models/schedules.py | 8 +-- superset/security/analytics_db_safety.py | 3 +- superset/security/manager.py | 7 ++- superset/sql_validators/presto_db.py | 3 +- superset/tasks/cache.py | 32 +++++----- superset/tasks/celery_app.py | 2 +- superset/tasks/schedules.py | 77 +++++++++++++++++------- superset/tasks/thumbnails.py | 8 +-- superset/utils/core.py | 54 ++++++++--------- 12 files changed, 121 insertions(+), 79 deletions(-) diff --git a/setup.cfg b/setup.cfg index d58d80a1b..9469118b2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -53,7 +53,7 @@ order_by_type = false ignore_missing_imports = true no_implicit_optional = true -[mypy-superset.bin.*,superset.charts.*,superset.datasets.*,superset.dashboards.*,superset.commands.*,superset.common.*,superset.dao.*,superset.db_engine_specs.*,superset.db_engines.*,superset.examples.*,superset.migrations.*] +[mypy-superset.bin.*,superset.charts.*,superset.datasets.*,superset.dashboards.*,superset.commands.*,superset.common.*,superset.dao.*,superset.db_engine_specs.*,superset.db_engines.*,superset.examples.*,superset.migrations.*,superset.queries.*,superset.security.*,superset.sql_validators.*,superset.tasks.*] check_untyped_defs = true disallow_untyped_calls = true disallow_untyped_defs = true diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py index eef20e215..8b841cc34 100644 --- a/superset/connectors/druid/models.py +++ b/superset/connectors/druid/models.py @@ -823,6 +823,7 @@ class DruidDatasource(Model, BaseDatasource): if origin: dttm = utils.parse_human_datetime(origin) + assert dttm granularity["origin"] = dttm.isoformat() if period_name in iso_8601_dict: @@ -978,6 +979,7 @@ class DruidDatasource(Model, BaseDatasource): # TODO: Use Lexicographic TopNMetricSpec once supported by PyDruid if self.fetch_values_from: from_dttm = utils.parse_human_datetime(self.fetch_values_from) + assert from_dttm else: from_dttm = datetime(1970, 1, 1) diff --git a/superset/exceptions.py b/superset/exceptions.py index e7f2e2d40..33841cd3c 100644 --- a/superset/exceptions.py +++ b/superset/exceptions.py @@ -41,7 +41,7 @@ class SupersetTimeoutException(SupersetException): class SupersetSecurityException(SupersetException): status = 401 - def __init__(self, msg, link=None): + def __init__(self, msg: str, link: Optional[str] = None) -> None: super(SupersetSecurityException, self).__init__(msg) self.link = link diff --git a/superset/models/schedules.py b/superset/models/schedules.py index 5d10b5676..6e2157ff4 100644 --- a/superset/models/schedules.py +++ b/superset/models/schedules.py @@ -15,8 +15,8 @@ # specific language governing permissions and limitations # under the License. """Models for scheduled execution of jobs""" - import enum +from typing import Optional, Type from flask_appbuilder import Model from sqlalchemy import Boolean, Column, Enum, ForeignKey, Integer, String, Text @@ -86,9 +86,9 @@ class SliceEmailSchedule(Model, AuditMixinNullable, ImportMixin, EmailSchedule): email_format = Column(Enum(SliceEmailReportFormat)) -def get_scheduler_model(report_type): - if report_type == ScheduleType.dashboard.value: +def get_scheduler_model(report_type: ScheduleType) -> Optional[Type[EmailSchedule]]: + if report_type == ScheduleType.dashboard: return DashboardEmailSchedule - elif report_type == ScheduleType.slice.value: + elif report_type == ScheduleType.slice: return SliceEmailSchedule return None diff --git a/superset/security/analytics_db_safety.py b/superset/security/analytics_db_safety.py index 64c7711f3..5c6a3f208 100644 --- a/superset/security/analytics_db_safety.py +++ b/superset/security/analytics_db_safety.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from sqlalchemy.engine.url import URL class DBSecurityException(Exception): @@ -22,7 +23,7 @@ class DBSecurityException(Exception): status = 400 -def check_sqlalchemy_uri(uri): +def check_sqlalchemy_uri(uri: URL) -> None: if uri.startswith("sqlite"): # sqlite creates a local DB, which allows mapping server's filesystem raise DBSecurityException( diff --git a/superset/security/manager.py b/superset/security/manager.py index 01c80d6cb..e3b4b1de9 100644 --- a/superset/security/manager.py +++ b/superset/security/manager.py @@ -38,6 +38,7 @@ from flask_appbuilder.widgets import ListWidget from sqlalchemy import or_ from sqlalchemy.engine.base import Connection from sqlalchemy.orm.mapper import Mapper +from sqlalchemy.orm.query import Query from superset import sql_parse from superset.connectors.connector_registry import ConnectorRegistry @@ -70,7 +71,7 @@ class SupersetRoleListWidget(ListWidget): template = "superset/fab_overrides/list_role.html" - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any) -> None: kwargs["appbuilder"] = current_app.appbuilder super().__init__(**kwargs) @@ -580,7 +581,7 @@ class SupersetSecurityManager(SecurityManager): if pv.permission and pv.view_menu: all_pvs.add((pv.permission.name, pv.view_menu.name)) - def merge_pv(view_menu, perm): + def merge_pv(view_menu: str, perm: str) -> None: """Create permission view menu only if it doesn't exist""" if view_menu and perm and (view_menu, perm) not in all_pvs: self.add_permission_view_menu(view_menu, perm) @@ -899,7 +900,7 @@ class SupersetSecurityManager(SecurityManager): self.assert_datasource_permission(viz.datasource) - def get_rls_filters(self, table: "BaseDatasource"): + def get_rls_filters(self, table: "BaseDatasource") -> List[Query]: """ Retrieves the appropriate row level security filters for the current user and the passed table. diff --git a/superset/sql_validators/presto_db.py b/superset/sql_validators/presto_db.py index fc5efda27..42e7cffaa 100644 --- a/superset/sql_validators/presto_db.py +++ b/superset/sql_validators/presto_db.py @@ -23,6 +23,7 @@ from typing import Any, Dict, List, Optional from flask import g from superset import app, security_manager +from superset.models.core import Database from superset.sql_parse import ParsedQuery from superset.sql_validators.base import BaseSQLValidator, SQLValidationAnnotation from superset.utils.core import QuerySource @@ -44,7 +45,7 @@ class PrestoDBSQLValidator(BaseSQLValidator): @classmethod def validate_statement( - cls, statement, database, cursor, user_name + cls, statement: str, database: Database, cursor: Any, user_name: str ) -> Optional[SQLValidationAnnotation]: # pylint: disable=too-many-locals db_engine_spec = database.db_engine_spec diff --git a/superset/tasks/cache.py b/superset/tasks/cache.py index 67c366ba2..b530deb3f 100644 --- a/superset/tasks/cache.py +++ b/superset/tasks/cache.py @@ -18,7 +18,7 @@ import json import logging -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional, Union from urllib import request from urllib.error import URLError @@ -38,7 +38,9 @@ logger = get_task_logger(__name__) logger.setLevel(logging.INFO) -def get_form_data(chart_id, dashboard=None): +def get_form_data( + chart_id: int, dashboard: Optional[Dashboard] = None +) -> Dict[str, Any]: """ Build `form_data` for chart GET request from dashboard's `default_filters`. @@ -46,7 +48,7 @@ def get_form_data(chart_id, dashboard=None): filters in the GET request for charts. """ - form_data = {"slice_id": chart_id} + form_data: Dict[str, Any] = {"slice_id": chart_id} if dashboard is None or not dashboard.json_metadata: return form_data @@ -72,7 +74,7 @@ def get_form_data(chart_id, dashboard=None): return form_data -def get_url(chart, extra_filters: Optional[Dict[str, Any]] = None): +def get_url(chart: Slice, extra_filters: Optional[Dict[str, Any]] = None) -> str: """Return external URL for warming up a given chart/table cache.""" with app.test_request_context(): baseurl = ( @@ -106,10 +108,10 @@ class Strategy: """ - def __init__(self): + def __init__(self) -> None: pass - def get_urls(self): + def get_urls(self) -> List[str]: raise NotImplementedError("Subclasses must implement get_urls!") @@ -131,7 +133,7 @@ class DummyStrategy(Strategy): name = "dummy" - def get_urls(self): + def get_urls(self) -> List[str]: session = db.create_scoped_session() charts = session.query(Slice).all() @@ -158,12 +160,12 @@ class TopNDashboardsStrategy(Strategy): name = "top_n_dashboards" - def __init__(self, top_n=5, since="7 days ago"): + def __init__(self, top_n: int = 5, since: str = "7 days ago") -> None: super(TopNDashboardsStrategy, self).__init__() self.top_n = top_n self.since = parse_human_datetime(since) - def get_urls(self): + def get_urls(self) -> List[str]: urls = [] session = db.create_scoped_session() @@ -203,11 +205,11 @@ class DashboardTagsStrategy(Strategy): name = "dashboard_tags" - def __init__(self, tags=None): + def __init__(self, tags: Optional[List[str]] = None) -> None: super(DashboardTagsStrategy, self).__init__() self.tags = tags or [] - def get_urls(self): + def get_urls(self) -> List[str]: urls = [] session = db.create_scoped_session() @@ -254,7 +256,9 @@ strategies = [DummyStrategy, TopNDashboardsStrategy, DashboardTagsStrategy] @celery_app.task(name="cache-warmup") -def cache_warmup(strategy_name, *args, **kwargs): +def cache_warmup( + strategy_name: str, *args: Any, **kwargs: Any +) -> Union[Dict[str, List[str]], str]: """ Warm up cache. @@ -264,7 +268,7 @@ def cache_warmup(strategy_name, *args, **kwargs): logger.info("Loading strategy") class_ = None for class_ in strategies: - if class_.name == strategy_name: + if class_.name == strategy_name: # type: ignore break else: message = f"No strategy {strategy_name} found!" @@ -280,7 +284,7 @@ def cache_warmup(strategy_name, *args, **kwargs): logger.exception(message) return message - results = {"success": [], "errors": []} + results: Dict[str, List[str]] = {"success": [], "errors": []} for url in strategy.get_urls(): try: logger.info(f"Fetching {url}") diff --git a/superset/tasks/celery_app.py b/superset/tasks/celery_app.py index 0f3cd0ef5..0344b59f8 100644 --- a/superset/tasks/celery_app.py +++ b/superset/tasks/celery_app.py @@ -25,7 +25,7 @@ from superset import create_app from superset.extensions import celery_app # Init the Flask app / configure everything -create_app() +create_app() # type: ignore # Need to import late, as the celery_app will have been setup by "create_app()" # pylint: disable=wrong-import-position, unused-import diff --git a/superset/tasks/schedules.py b/superset/tasks/schedules.py index 3889d021f..45036a8b2 100644 --- a/superset/tasks/schedules.py +++ b/superset/tasks/schedules.py @@ -23,10 +23,12 @@ import urllib.request from collections import namedtuple from datetime import datetime, timedelta from email.utils import make_msgid, parseaddr +from typing import Any, Dict, Iterator, List, Optional, Tuple, Union from urllib.error import URLError # pylint: disable=ungrouped-imports import croniter import simplejson as json +from celery.app.task import Task from dateutil.tz import tzlocal from flask import render_template, Response, session, url_for from flask_babel import gettext as __ @@ -34,16 +36,20 @@ from flask_login import login_user from retry.api import retry_call from selenium.common.exceptions import WebDriverException from selenium.webdriver import chrome, firefox +from werkzeug.datastructures import TypeConversionDict from werkzeug.http import parse_cookie # Superset framework imports from superset import app, db, security_manager from superset.extensions import celery_app from superset.models.schedules import ( + DashboardEmailSchedule, EmailDeliveryType, + EmailSchedule, get_scheduler_model, ScheduleType, SliceEmailReportFormat, + SliceEmailSchedule, ) from superset.utils.core import get_email_address_list, send_email_smtp @@ -59,7 +65,9 @@ PAGE_RENDER_WAIT = 30 EmailContent = namedtuple("EmailContent", ["body", "data", "images"]) -def _get_recipients(schedule): +def _get_recipients( + schedule: Union[DashboardEmailSchedule, SliceEmailSchedule] +) -> Iterator[Tuple[str, str]]: bcc = config["EMAIL_REPORT_BCC_ADDRESS"] if schedule.deliver_as_group: @@ -70,7 +78,11 @@ def _get_recipients(schedule): yield (to, bcc) -def _deliver_email(schedule, subject, email): +def _deliver_email( + schedule: Union[DashboardEmailSchedule, SliceEmailSchedule], + subject: str, + email: EmailContent, +) -> None: for (to, bcc) in _get_recipients(schedule): send_email_smtp( to, @@ -85,7 +97,11 @@ def _deliver_email(schedule, subject, email): ) -def _generate_mail_content(schedule, screenshot, name, url): +def _generate_mail_content( + schedule: EmailSchedule, screenshot: bytes, name: str, url: str +) -> EmailContent: + data: Optional[Dict[str, Any]] + if schedule.delivery_type == EmailDeliveryType.attachment: images = None data = {"screenshot.png": screenshot} @@ -115,7 +131,7 @@ def _generate_mail_content(schedule, screenshot, name, url): return EmailContent(body, data, images) -def _get_auth_cookies(): +def _get_auth_cookies() -> List[TypeConversionDict]: # Login with the user specified to get the reports with app.test_request_context(): user = security_manager.find_user(config["EMAIL_REPORTS_USER"]) @@ -136,14 +152,16 @@ def _get_auth_cookies(): return cookies -def _get_url_path(view, **kwargs): +def _get_url_path(view: str, **kwargs: Any) -> str: with app.test_request_context(): return urllib.parse.urljoin( str(config["WEBDRIVER_BASEURL"]), url_for(view, **kwargs) ) -def create_webdriver(): +def create_webdriver() -> Union[ + chrome.webdriver.WebDriver, firefox.webdriver.WebDriver +]: # Create a webdriver for use in fetching reports if config["EMAIL_REPORTS_WEBDRIVER"] == "firefox": driver_class = firefox.webdriver.WebDriver @@ -181,7 +199,9 @@ def create_webdriver(): return driver -def destroy_webdriver(driver): +def destroy_webdriver( + driver: Union[chrome.webdriver.WebDriver, firefox.webdriver.WebDriver] +) -> None: """ Destroy a driver """ @@ -198,7 +218,7 @@ def destroy_webdriver(driver): pass -def deliver_dashboard(schedule): +def deliver_dashboard(schedule: DashboardEmailSchedule) -> None: """ Given a schedule, delivery the dashboard as an email report """ @@ -243,7 +263,7 @@ def deliver_dashboard(schedule): _deliver_email(schedule, subject, email) -def _get_slice_data(schedule): +def _get_slice_data(schedule: SliceEmailSchedule) -> EmailContent: slc = schedule.slice slice_url = _get_url_path( @@ -272,7 +292,7 @@ def _get_slice_data(schedule): # Parse the csv file and generate HTML columns = rows.pop(0) - with app.app_context(): + with app.app_context(): # type: ignore body = render_template( "superset/reports/slice_data.html", columns=columns, @@ -292,7 +312,7 @@ def _get_slice_data(schedule): return EmailContent(body, data, None) -def _get_slice_visualization(schedule): +def _get_slice_visualization(schedule: SliceEmailSchedule) -> EmailContent: slc = schedule.slice # Create a driver, fetch the page, wait for the page to render @@ -327,7 +347,7 @@ def _get_slice_visualization(schedule): return _generate_mail_content(schedule, screenshot, slc.slice_name, slice_url) -def deliver_slice(schedule): +def deliver_slice(schedule: Union[DashboardEmailSchedule, SliceEmailSchedule]) -> None: """ Given a schedule, delivery the slice as an email report """ @@ -352,9 +372,12 @@ def deliver_slice(schedule): bind=True, soft_time_limit=config["EMAIL_ASYNC_TIME_LIMIT_SEC"], ) -def schedule_email_report( - task, report_type, schedule_id, recipients=None -): # pylint: disable=unused-argument +def schedule_email_report( # pylint: disable=unused-argument + task: Task, + report_type: ScheduleType, + schedule_id: int, + recipients: Optional[str] = None, +) -> None: model_cls = get_scheduler_model(report_type) schedule = db.create_scoped_session().query(model_cls).get(schedule_id) @@ -368,15 +391,17 @@ def schedule_email_report( schedule.id = schedule_id schedule.recipients = recipients - if report_type == ScheduleType.dashboard.value: + if report_type == ScheduleType.dashboard: deliver_dashboard(schedule) - elif report_type == ScheduleType.slice.value: + elif report_type == ScheduleType.slice: deliver_slice(schedule) else: raise RuntimeError("Unknown report type") -def next_schedules(crontab, start_at, stop_at, resolution=0): +def next_schedules( + crontab: str, start_at: datetime, stop_at: datetime, resolution: int = 0 +) -> Iterator[datetime]: crons = croniter.croniter(crontab, start_at - timedelta(seconds=1)) previous = start_at - timedelta(days=1) @@ -396,13 +421,19 @@ def next_schedules(crontab, start_at, stop_at, resolution=0): previous = eta -def schedule_window(report_type, start_at, stop_at, resolution): +def schedule_window( + report_type: ScheduleType, start_at: datetime, stop_at: datetime, resolution: int +) -> None: """ Find all active schedules and schedule celery tasks for each of them with a specific ETA (determined by parsing the cron schedule for the schedule) """ model_cls = get_scheduler_model(report_type) + + if not model_cls: + return None + dbsession = db.create_scoped_session() schedules = dbsession.query(model_cls).filter(model_cls.active.is_(True)) @@ -415,9 +446,11 @@ def schedule_window(report_type, start_at, stop_at, resolution): ): schedule_email_report.apply_async(args, eta=eta) + return None + @celery_app.task(name="email_reports.schedule_hourly") -def schedule_hourly(): +def schedule_hourly() -> None: """ Celery beat job meant to be invoked hourly """ if not config["ENABLE_SCHEDULED_EMAIL_REPORTS"]: @@ -429,5 +462,5 @@ def schedule_hourly(): # Get the top of the hour start_at = datetime.now(tzlocal()).replace(microsecond=0, second=0, minute=0) stop_at = start_at + timedelta(seconds=3600) - schedule_window(ScheduleType.dashboard.value, start_at, stop_at, resolution) - schedule_window(ScheduleType.slice.value, start_at, stop_at, resolution) + schedule_window(ScheduleType.dashboard, start_at, stop_at, resolution) + schedule_window(ScheduleType.slice, start_at, stop_at, resolution) diff --git a/superset/tasks/thumbnails.py b/superset/tasks/thumbnails.py index 72c7bdaf6..119770045 100644 --- a/superset/tasks/thumbnails.py +++ b/superset/tasks/thumbnails.py @@ -30,8 +30,8 @@ logger = logging.getLogger(__name__) @celery_app.task(name="cache_chart_thumbnail", soft_time_limit=300) -def cache_chart_thumbnail(chart_id: int, force: bool = False): - with app.app_context(): +def cache_chart_thumbnail(chart_id: int, force: bool = False) -> None: + with app.app_context(): # type: ignore if not thumbnail_cache: logger.warning("No cache set, refusing to compute") return None @@ -42,8 +42,8 @@ def cache_chart_thumbnail(chart_id: int, force: bool = False): @celery_app.task(name="cache_dashboard_thumbnail", soft_time_limit=300) -def cache_dashboard_thumbnail(dashboard_id: int, force: bool = False): - with app.app_context(): +def cache_dashboard_thumbnail(dashboard_id: int, force: bool = False) -> None: + with app.app_context(): # type: ignore if not thumbnail_cache: logging.warning("No cache set, refusing to compute") return None diff --git a/superset/utils/core.py b/superset/utils/core.py index ba715dde4..41deae504 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -235,7 +235,7 @@ def list_minus(l: List, minus: List) -> List: return [o for o in l if o not in minus] -def parse_human_datetime(s): +def parse_human_datetime(s: Optional[str]) -> Optional[datetime]: """ Returns ``datetime.datetime`` from human readable strings @@ -687,42 +687,42 @@ def notify_user_about_perm_udate(granter, user, role, datasource, tpl_name, conf def send_email_smtp( - to, - subject, - html_content, - config, - files=None, - data=None, - images=None, - dryrun=False, - cc=None, - bcc=None, - mime_subtype="mixed", -): + to: str, + subject: str, + html_content: str, + config: Dict[str, Any], + files: Optional[List[str]] = None, + data: Optional[Dict[str, str]] = None, + images: Optional[Dict[str, str]] = None, + dryrun: bool = False, + cc: Optional[str] = None, + bcc: Optional[str] = None, + mime_subtype: str = "mixed", +) -> None: """ Send an email with html content, eg: send_email_smtp( 'test@example.com', 'foo', 'Foo bar',['/dev/null'], dryrun=True) """ smtp_mail_from = config["SMTP_MAIL_FROM"] - to = get_email_address_list(to) + smtp_mail_to = get_email_address_list(to) msg = MIMEMultipart(mime_subtype) msg["Subject"] = subject msg["From"] = smtp_mail_from - msg["To"] = ", ".join(to) + msg["To"] = ", ".join(smtp_mail_to) msg.preamble = "This is a multi-part message in MIME format." - recipients = to + recipients = smtp_mail_to if cc: - cc = get_email_address_list(cc) - msg["CC"] = ", ".join(cc) - recipients = recipients + cc + smtp_mail_cc = get_email_address_list(cc) + msg["CC"] = ", ".join(smtp_mail_cc) + recipients = recipients + smtp_mail_cc if bcc: # don't add bcc in header - bcc = get_email_address_list(bcc) - recipients = recipients + bcc + smtp_mail_bcc = get_email_address_list(bcc) + recipients = recipients + smtp_mail_bcc msg["Date"] = formatdate(localtime=True) mime_text = MIMEText(html_content, "html") @@ -1034,8 +1034,8 @@ def get_since_until( """ separator = " : " - relative_start = parse_human_datetime(relative_start if relative_start else "today") - relative_end = parse_human_datetime(relative_end if relative_end else "today") + relative_start = parse_human_datetime(relative_start if relative_start else "today") # type: ignore + relative_end = parse_human_datetime(relative_end if relative_end else "today") # type: ignore common_time_frames = { "Last day": ( relative_start - relativedelta(days=1), # type: ignore @@ -1064,8 +1064,8 @@ def get_since_until( since, until = time_range.split(separator, 1) if since and since not in common_time_frames: since = add_ago_to_since(since) - since = parse_human_datetime(since) - until = parse_human_datetime(until) + since = parse_human_datetime(since) # type: ignore + until = parse_human_datetime(until) # type: ignore elif time_range in common_time_frames: since, until = common_time_frames[time_range] elif time_range == "No filter": @@ -1086,8 +1086,8 @@ def get_since_until( since = since or "" if since: since = add_ago_to_since(since) - since = parse_human_datetime(since) - until = parse_human_datetime(until) if until else relative_end + since = parse_human_datetime(since) # type: ignore + until = parse_human_datetime(until) if until else relative_end # type: ignore if time_shift: time_delta = parse_past_timedelta(time_shift)