[mypy] Enforcing typing for a number of modules (#9586)
Co-authored-by: John Bodley <john.bodley@airbnb.com>
This commit is contained in:
parent
7d5f4494d0
commit
1c656feb95
|
|
@ -53,7 +53,7 @@ order_by_type = false
|
||||||
ignore_missing_imports = true
|
ignore_missing_imports = true
|
||||||
no_implicit_optional = true
|
no_implicit_optional = true
|
||||||
|
|
||||||
[mypy-superset.bin.*,superset.charts.*,superset.datasets.*,superset.dashboards.*,superset.commands.*,superset.common.*,superset.dao.*,superset.db_engine_specs.*,superset.db_engines.*,superset.examples.*,superset.migrations.*]
|
[mypy-superset.bin.*,superset.charts.*,superset.datasets.*,superset.dashboards.*,superset.commands.*,superset.common.*,superset.dao.*,superset.db_engine_specs.*,superset.db_engines.*,superset.examples.*,superset.migrations.*,superset.queries.*,superset.security.*,superset.sql_validators.*,superset.tasks.*]
|
||||||
check_untyped_defs = true
|
check_untyped_defs = true
|
||||||
disallow_untyped_calls = true
|
disallow_untyped_calls = true
|
||||||
disallow_untyped_defs = true
|
disallow_untyped_defs = true
|
||||||
|
|
|
||||||
|
|
@ -823,6 +823,7 @@ class DruidDatasource(Model, BaseDatasource):
|
||||||
|
|
||||||
if origin:
|
if origin:
|
||||||
dttm = utils.parse_human_datetime(origin)
|
dttm = utils.parse_human_datetime(origin)
|
||||||
|
assert dttm
|
||||||
granularity["origin"] = dttm.isoformat()
|
granularity["origin"] = dttm.isoformat()
|
||||||
|
|
||||||
if period_name in iso_8601_dict:
|
if period_name in iso_8601_dict:
|
||||||
|
|
@ -978,6 +979,7 @@ class DruidDatasource(Model, BaseDatasource):
|
||||||
# TODO: Use Lexicographic TopNMetricSpec once supported by PyDruid
|
# TODO: Use Lexicographic TopNMetricSpec once supported by PyDruid
|
||||||
if self.fetch_values_from:
|
if self.fetch_values_from:
|
||||||
from_dttm = utils.parse_human_datetime(self.fetch_values_from)
|
from_dttm = utils.parse_human_datetime(self.fetch_values_from)
|
||||||
|
assert from_dttm
|
||||||
else:
|
else:
|
||||||
from_dttm = datetime(1970, 1, 1)
|
from_dttm = datetime(1970, 1, 1)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -41,7 +41,7 @@ class SupersetTimeoutException(SupersetException):
|
||||||
class SupersetSecurityException(SupersetException):
|
class SupersetSecurityException(SupersetException):
|
||||||
status = 401
|
status = 401
|
||||||
|
|
||||||
def __init__(self, msg, link=None):
|
def __init__(self, msg: str, link: Optional[str] = None) -> None:
|
||||||
super(SupersetSecurityException, self).__init__(msg)
|
super(SupersetSecurityException, self).__init__(msg)
|
||||||
self.link = link
|
self.link = link
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -15,8 +15,8 @@
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
"""Models for scheduled execution of jobs"""
|
"""Models for scheduled execution of jobs"""
|
||||||
|
|
||||||
import enum
|
import enum
|
||||||
|
from typing import Optional, Type
|
||||||
|
|
||||||
from flask_appbuilder import Model
|
from flask_appbuilder import Model
|
||||||
from sqlalchemy import Boolean, Column, Enum, ForeignKey, Integer, String, Text
|
from sqlalchemy import Boolean, Column, Enum, ForeignKey, Integer, String, Text
|
||||||
|
|
@ -86,9 +86,9 @@ class SliceEmailSchedule(Model, AuditMixinNullable, ImportMixin, EmailSchedule):
|
||||||
email_format = Column(Enum(SliceEmailReportFormat))
|
email_format = Column(Enum(SliceEmailReportFormat))
|
||||||
|
|
||||||
|
|
||||||
def get_scheduler_model(report_type):
|
def get_scheduler_model(report_type: ScheduleType) -> Optional[Type[EmailSchedule]]:
|
||||||
if report_type == ScheduleType.dashboard.value:
|
if report_type == ScheduleType.dashboard:
|
||||||
return DashboardEmailSchedule
|
return DashboardEmailSchedule
|
||||||
elif report_type == ScheduleType.slice.value:
|
elif report_type == ScheduleType.slice:
|
||||||
return SliceEmailSchedule
|
return SliceEmailSchedule
|
||||||
return None
|
return None
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,7 @@
|
||||||
# KIND, either express or implied. See the License for the
|
# KIND, either express or implied. See the License for the
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
|
from sqlalchemy.engine.url import URL
|
||||||
|
|
||||||
|
|
||||||
class DBSecurityException(Exception):
|
class DBSecurityException(Exception):
|
||||||
|
|
@ -22,7 +23,7 @@ class DBSecurityException(Exception):
|
||||||
status = 400
|
status = 400
|
||||||
|
|
||||||
|
|
||||||
def check_sqlalchemy_uri(uri):
|
def check_sqlalchemy_uri(uri: URL) -> None:
|
||||||
if uri.startswith("sqlite"):
|
if uri.startswith("sqlite"):
|
||||||
# sqlite creates a local DB, which allows mapping server's filesystem
|
# sqlite creates a local DB, which allows mapping server's filesystem
|
||||||
raise DBSecurityException(
|
raise DBSecurityException(
|
||||||
|
|
|
||||||
|
|
@ -38,6 +38,7 @@ from flask_appbuilder.widgets import ListWidget
|
||||||
from sqlalchemy import or_
|
from sqlalchemy import or_
|
||||||
from sqlalchemy.engine.base import Connection
|
from sqlalchemy.engine.base import Connection
|
||||||
from sqlalchemy.orm.mapper import Mapper
|
from sqlalchemy.orm.mapper import Mapper
|
||||||
|
from sqlalchemy.orm.query import Query
|
||||||
|
|
||||||
from superset import sql_parse
|
from superset import sql_parse
|
||||||
from superset.connectors.connector_registry import ConnectorRegistry
|
from superset.connectors.connector_registry import ConnectorRegistry
|
||||||
|
|
@ -70,7 +71,7 @@ class SupersetRoleListWidget(ListWidget):
|
||||||
|
|
||||||
template = "superset/fab_overrides/list_role.html"
|
template = "superset/fab_overrides/list_role.html"
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs: Any) -> None:
|
||||||
kwargs["appbuilder"] = current_app.appbuilder
|
kwargs["appbuilder"] = current_app.appbuilder
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
|
@ -580,7 +581,7 @@ class SupersetSecurityManager(SecurityManager):
|
||||||
if pv.permission and pv.view_menu:
|
if pv.permission and pv.view_menu:
|
||||||
all_pvs.add((pv.permission.name, pv.view_menu.name))
|
all_pvs.add((pv.permission.name, pv.view_menu.name))
|
||||||
|
|
||||||
def merge_pv(view_menu, perm):
|
def merge_pv(view_menu: str, perm: str) -> None:
|
||||||
"""Create permission view menu only if it doesn't exist"""
|
"""Create permission view menu only if it doesn't exist"""
|
||||||
if view_menu and perm and (view_menu, perm) not in all_pvs:
|
if view_menu and perm and (view_menu, perm) not in all_pvs:
|
||||||
self.add_permission_view_menu(view_menu, perm)
|
self.add_permission_view_menu(view_menu, perm)
|
||||||
|
|
@ -899,7 +900,7 @@ class SupersetSecurityManager(SecurityManager):
|
||||||
|
|
||||||
self.assert_datasource_permission(viz.datasource)
|
self.assert_datasource_permission(viz.datasource)
|
||||||
|
|
||||||
def get_rls_filters(self, table: "BaseDatasource"):
|
def get_rls_filters(self, table: "BaseDatasource") -> List[Query]:
|
||||||
"""
|
"""
|
||||||
Retrieves the appropriate row level security filters for the current user and the passed table.
|
Retrieves the appropriate row level security filters for the current user and the passed table.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -23,6 +23,7 @@ from typing import Any, Dict, List, Optional
|
||||||
from flask import g
|
from flask import g
|
||||||
|
|
||||||
from superset import app, security_manager
|
from superset import app, security_manager
|
||||||
|
from superset.models.core import Database
|
||||||
from superset.sql_parse import ParsedQuery
|
from superset.sql_parse import ParsedQuery
|
||||||
from superset.sql_validators.base import BaseSQLValidator, SQLValidationAnnotation
|
from superset.sql_validators.base import BaseSQLValidator, SQLValidationAnnotation
|
||||||
from superset.utils.core import QuerySource
|
from superset.utils.core import QuerySource
|
||||||
|
|
@ -44,7 +45,7 @@ class PrestoDBSQLValidator(BaseSQLValidator):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_statement(
|
def validate_statement(
|
||||||
cls, statement, database, cursor, user_name
|
cls, statement: str, database: Database, cursor: Any, user_name: str
|
||||||
) -> Optional[SQLValidationAnnotation]:
|
) -> Optional[SQLValidationAnnotation]:
|
||||||
# pylint: disable=too-many-locals
|
# pylint: disable=too-many-locals
|
||||||
db_engine_spec = database.db_engine_spec
|
db_engine_spec = database.db_engine_spec
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,7 @@
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, List, Optional, Union
|
||||||
from urllib import request
|
from urllib import request
|
||||||
from urllib.error import URLError
|
from urllib.error import URLError
|
||||||
|
|
||||||
|
|
@ -38,7 +38,9 @@ logger = get_task_logger(__name__)
|
||||||
logger.setLevel(logging.INFO)
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
def get_form_data(chart_id, dashboard=None):
|
def get_form_data(
|
||||||
|
chart_id: int, dashboard: Optional[Dashboard] = None
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Build `form_data` for chart GET request from dashboard's `default_filters`.
|
Build `form_data` for chart GET request from dashboard's `default_filters`.
|
||||||
|
|
||||||
|
|
@ -46,7 +48,7 @@ def get_form_data(chart_id, dashboard=None):
|
||||||
filters in the GET request for charts.
|
filters in the GET request for charts.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
form_data = {"slice_id": chart_id}
|
form_data: Dict[str, Any] = {"slice_id": chart_id}
|
||||||
|
|
||||||
if dashboard is None or not dashboard.json_metadata:
|
if dashboard is None or not dashboard.json_metadata:
|
||||||
return form_data
|
return form_data
|
||||||
|
|
@ -72,7 +74,7 @@ def get_form_data(chart_id, dashboard=None):
|
||||||
return form_data
|
return form_data
|
||||||
|
|
||||||
|
|
||||||
def get_url(chart, extra_filters: Optional[Dict[str, Any]] = None):
|
def get_url(chart: Slice, extra_filters: Optional[Dict[str, Any]] = None) -> str:
|
||||||
"""Return external URL for warming up a given chart/table cache."""
|
"""Return external URL for warming up a given chart/table cache."""
|
||||||
with app.test_request_context():
|
with app.test_request_context():
|
||||||
baseurl = (
|
baseurl = (
|
||||||
|
|
@ -106,10 +108,10 @@ class Strategy:
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def get_urls(self):
|
def get_urls(self) -> List[str]:
|
||||||
raise NotImplementedError("Subclasses must implement get_urls!")
|
raise NotImplementedError("Subclasses must implement get_urls!")
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -131,7 +133,7 @@ class DummyStrategy(Strategy):
|
||||||
|
|
||||||
name = "dummy"
|
name = "dummy"
|
||||||
|
|
||||||
def get_urls(self):
|
def get_urls(self) -> List[str]:
|
||||||
session = db.create_scoped_session()
|
session = db.create_scoped_session()
|
||||||
charts = session.query(Slice).all()
|
charts = session.query(Slice).all()
|
||||||
|
|
||||||
|
|
@ -158,12 +160,12 @@ class TopNDashboardsStrategy(Strategy):
|
||||||
|
|
||||||
name = "top_n_dashboards"
|
name = "top_n_dashboards"
|
||||||
|
|
||||||
def __init__(self, top_n=5, since="7 days ago"):
|
def __init__(self, top_n: int = 5, since: str = "7 days ago") -> None:
|
||||||
super(TopNDashboardsStrategy, self).__init__()
|
super(TopNDashboardsStrategy, self).__init__()
|
||||||
self.top_n = top_n
|
self.top_n = top_n
|
||||||
self.since = parse_human_datetime(since)
|
self.since = parse_human_datetime(since)
|
||||||
|
|
||||||
def get_urls(self):
|
def get_urls(self) -> List[str]:
|
||||||
urls = []
|
urls = []
|
||||||
session = db.create_scoped_session()
|
session = db.create_scoped_session()
|
||||||
|
|
||||||
|
|
@ -203,11 +205,11 @@ class DashboardTagsStrategy(Strategy):
|
||||||
|
|
||||||
name = "dashboard_tags"
|
name = "dashboard_tags"
|
||||||
|
|
||||||
def __init__(self, tags=None):
|
def __init__(self, tags: Optional[List[str]] = None) -> None:
|
||||||
super(DashboardTagsStrategy, self).__init__()
|
super(DashboardTagsStrategy, self).__init__()
|
||||||
self.tags = tags or []
|
self.tags = tags or []
|
||||||
|
|
||||||
def get_urls(self):
|
def get_urls(self) -> List[str]:
|
||||||
urls = []
|
urls = []
|
||||||
session = db.create_scoped_session()
|
session = db.create_scoped_session()
|
||||||
|
|
||||||
|
|
@ -254,7 +256,9 @@ strategies = [DummyStrategy, TopNDashboardsStrategy, DashboardTagsStrategy]
|
||||||
|
|
||||||
|
|
||||||
@celery_app.task(name="cache-warmup")
|
@celery_app.task(name="cache-warmup")
|
||||||
def cache_warmup(strategy_name, *args, **kwargs):
|
def cache_warmup(
|
||||||
|
strategy_name: str, *args: Any, **kwargs: Any
|
||||||
|
) -> Union[Dict[str, List[str]], str]:
|
||||||
"""
|
"""
|
||||||
Warm up cache.
|
Warm up cache.
|
||||||
|
|
||||||
|
|
@ -264,7 +268,7 @@ def cache_warmup(strategy_name, *args, **kwargs):
|
||||||
logger.info("Loading strategy")
|
logger.info("Loading strategy")
|
||||||
class_ = None
|
class_ = None
|
||||||
for class_ in strategies:
|
for class_ in strategies:
|
||||||
if class_.name == strategy_name:
|
if class_.name == strategy_name: # type: ignore
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
message = f"No strategy {strategy_name} found!"
|
message = f"No strategy {strategy_name} found!"
|
||||||
|
|
@ -280,7 +284,7 @@ def cache_warmup(strategy_name, *args, **kwargs):
|
||||||
logger.exception(message)
|
logger.exception(message)
|
||||||
return message
|
return message
|
||||||
|
|
||||||
results = {"success": [], "errors": []}
|
results: Dict[str, List[str]] = {"success": [], "errors": []}
|
||||||
for url in strategy.get_urls():
|
for url in strategy.get_urls():
|
||||||
try:
|
try:
|
||||||
logger.info(f"Fetching {url}")
|
logger.info(f"Fetching {url}")
|
||||||
|
|
|
||||||
|
|
@ -25,7 +25,7 @@ from superset import create_app
|
||||||
from superset.extensions import celery_app
|
from superset.extensions import celery_app
|
||||||
|
|
||||||
# Init the Flask app / configure everything
|
# Init the Flask app / configure everything
|
||||||
create_app()
|
create_app() # type: ignore
|
||||||
|
|
||||||
# Need to import late, as the celery_app will have been setup by "create_app()"
|
# Need to import late, as the celery_app will have been setup by "create_app()"
|
||||||
# pylint: disable=wrong-import-position, unused-import
|
# pylint: disable=wrong-import-position, unused-import
|
||||||
|
|
|
||||||
|
|
@ -23,10 +23,12 @@ import urllib.request
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from email.utils import make_msgid, parseaddr
|
from email.utils import make_msgid, parseaddr
|
||||||
|
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
|
||||||
from urllib.error import URLError # pylint: disable=ungrouped-imports
|
from urllib.error import URLError # pylint: disable=ungrouped-imports
|
||||||
|
|
||||||
import croniter
|
import croniter
|
||||||
import simplejson as json
|
import simplejson as json
|
||||||
|
from celery.app.task import Task
|
||||||
from dateutil.tz import tzlocal
|
from dateutil.tz import tzlocal
|
||||||
from flask import render_template, Response, session, url_for
|
from flask import render_template, Response, session, url_for
|
||||||
from flask_babel import gettext as __
|
from flask_babel import gettext as __
|
||||||
|
|
@ -34,16 +36,20 @@ from flask_login import login_user
|
||||||
from retry.api import retry_call
|
from retry.api import retry_call
|
||||||
from selenium.common.exceptions import WebDriverException
|
from selenium.common.exceptions import WebDriverException
|
||||||
from selenium.webdriver import chrome, firefox
|
from selenium.webdriver import chrome, firefox
|
||||||
|
from werkzeug.datastructures import TypeConversionDict
|
||||||
from werkzeug.http import parse_cookie
|
from werkzeug.http import parse_cookie
|
||||||
|
|
||||||
# Superset framework imports
|
# Superset framework imports
|
||||||
from superset import app, db, security_manager
|
from superset import app, db, security_manager
|
||||||
from superset.extensions import celery_app
|
from superset.extensions import celery_app
|
||||||
from superset.models.schedules import (
|
from superset.models.schedules import (
|
||||||
|
DashboardEmailSchedule,
|
||||||
EmailDeliveryType,
|
EmailDeliveryType,
|
||||||
|
EmailSchedule,
|
||||||
get_scheduler_model,
|
get_scheduler_model,
|
||||||
ScheduleType,
|
ScheduleType,
|
||||||
SliceEmailReportFormat,
|
SliceEmailReportFormat,
|
||||||
|
SliceEmailSchedule,
|
||||||
)
|
)
|
||||||
from superset.utils.core import get_email_address_list, send_email_smtp
|
from superset.utils.core import get_email_address_list, send_email_smtp
|
||||||
|
|
||||||
|
|
@ -59,7 +65,9 @@ PAGE_RENDER_WAIT = 30
|
||||||
EmailContent = namedtuple("EmailContent", ["body", "data", "images"])
|
EmailContent = namedtuple("EmailContent", ["body", "data", "images"])
|
||||||
|
|
||||||
|
|
||||||
def _get_recipients(schedule):
|
def _get_recipients(
|
||||||
|
schedule: Union[DashboardEmailSchedule, SliceEmailSchedule]
|
||||||
|
) -> Iterator[Tuple[str, str]]:
|
||||||
bcc = config["EMAIL_REPORT_BCC_ADDRESS"]
|
bcc = config["EMAIL_REPORT_BCC_ADDRESS"]
|
||||||
|
|
||||||
if schedule.deliver_as_group:
|
if schedule.deliver_as_group:
|
||||||
|
|
@ -70,7 +78,11 @@ def _get_recipients(schedule):
|
||||||
yield (to, bcc)
|
yield (to, bcc)
|
||||||
|
|
||||||
|
|
||||||
def _deliver_email(schedule, subject, email):
|
def _deliver_email(
|
||||||
|
schedule: Union[DashboardEmailSchedule, SliceEmailSchedule],
|
||||||
|
subject: str,
|
||||||
|
email: EmailContent,
|
||||||
|
) -> None:
|
||||||
for (to, bcc) in _get_recipients(schedule):
|
for (to, bcc) in _get_recipients(schedule):
|
||||||
send_email_smtp(
|
send_email_smtp(
|
||||||
to,
|
to,
|
||||||
|
|
@ -85,7 +97,11 @@ def _deliver_email(schedule, subject, email):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _generate_mail_content(schedule, screenshot, name, url):
|
def _generate_mail_content(
|
||||||
|
schedule: EmailSchedule, screenshot: bytes, name: str, url: str
|
||||||
|
) -> EmailContent:
|
||||||
|
data: Optional[Dict[str, Any]]
|
||||||
|
|
||||||
if schedule.delivery_type == EmailDeliveryType.attachment:
|
if schedule.delivery_type == EmailDeliveryType.attachment:
|
||||||
images = None
|
images = None
|
||||||
data = {"screenshot.png": screenshot}
|
data = {"screenshot.png": screenshot}
|
||||||
|
|
@ -115,7 +131,7 @@ def _generate_mail_content(schedule, screenshot, name, url):
|
||||||
return EmailContent(body, data, images)
|
return EmailContent(body, data, images)
|
||||||
|
|
||||||
|
|
||||||
def _get_auth_cookies():
|
def _get_auth_cookies() -> List[TypeConversionDict]:
|
||||||
# Login with the user specified to get the reports
|
# Login with the user specified to get the reports
|
||||||
with app.test_request_context():
|
with app.test_request_context():
|
||||||
user = security_manager.find_user(config["EMAIL_REPORTS_USER"])
|
user = security_manager.find_user(config["EMAIL_REPORTS_USER"])
|
||||||
|
|
@ -136,14 +152,16 @@ def _get_auth_cookies():
|
||||||
return cookies
|
return cookies
|
||||||
|
|
||||||
|
|
||||||
def _get_url_path(view, **kwargs):
|
def _get_url_path(view: str, **kwargs: Any) -> str:
|
||||||
with app.test_request_context():
|
with app.test_request_context():
|
||||||
return urllib.parse.urljoin(
|
return urllib.parse.urljoin(
|
||||||
str(config["WEBDRIVER_BASEURL"]), url_for(view, **kwargs)
|
str(config["WEBDRIVER_BASEURL"]), url_for(view, **kwargs)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_webdriver():
|
def create_webdriver() -> Union[
|
||||||
|
chrome.webdriver.WebDriver, firefox.webdriver.WebDriver
|
||||||
|
]:
|
||||||
# Create a webdriver for use in fetching reports
|
# Create a webdriver for use in fetching reports
|
||||||
if config["EMAIL_REPORTS_WEBDRIVER"] == "firefox":
|
if config["EMAIL_REPORTS_WEBDRIVER"] == "firefox":
|
||||||
driver_class = firefox.webdriver.WebDriver
|
driver_class = firefox.webdriver.WebDriver
|
||||||
|
|
@ -181,7 +199,9 @@ def create_webdriver():
|
||||||
return driver
|
return driver
|
||||||
|
|
||||||
|
|
||||||
def destroy_webdriver(driver):
|
def destroy_webdriver(
|
||||||
|
driver: Union[chrome.webdriver.WebDriver, firefox.webdriver.WebDriver]
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Destroy a driver
|
Destroy a driver
|
||||||
"""
|
"""
|
||||||
|
|
@ -198,7 +218,7 @@ def destroy_webdriver(driver):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def deliver_dashboard(schedule):
|
def deliver_dashboard(schedule: DashboardEmailSchedule) -> None:
|
||||||
"""
|
"""
|
||||||
Given a schedule, delivery the dashboard as an email report
|
Given a schedule, delivery the dashboard as an email report
|
||||||
"""
|
"""
|
||||||
|
|
@ -243,7 +263,7 @@ def deliver_dashboard(schedule):
|
||||||
_deliver_email(schedule, subject, email)
|
_deliver_email(schedule, subject, email)
|
||||||
|
|
||||||
|
|
||||||
def _get_slice_data(schedule):
|
def _get_slice_data(schedule: SliceEmailSchedule) -> EmailContent:
|
||||||
slc = schedule.slice
|
slc = schedule.slice
|
||||||
|
|
||||||
slice_url = _get_url_path(
|
slice_url = _get_url_path(
|
||||||
|
|
@ -272,7 +292,7 @@ def _get_slice_data(schedule):
|
||||||
|
|
||||||
# Parse the csv file and generate HTML
|
# Parse the csv file and generate HTML
|
||||||
columns = rows.pop(0)
|
columns = rows.pop(0)
|
||||||
with app.app_context():
|
with app.app_context(): # type: ignore
|
||||||
body = render_template(
|
body = render_template(
|
||||||
"superset/reports/slice_data.html",
|
"superset/reports/slice_data.html",
|
||||||
columns=columns,
|
columns=columns,
|
||||||
|
|
@ -292,7 +312,7 @@ def _get_slice_data(schedule):
|
||||||
return EmailContent(body, data, None)
|
return EmailContent(body, data, None)
|
||||||
|
|
||||||
|
|
||||||
def _get_slice_visualization(schedule):
|
def _get_slice_visualization(schedule: SliceEmailSchedule) -> EmailContent:
|
||||||
slc = schedule.slice
|
slc = schedule.slice
|
||||||
|
|
||||||
# Create a driver, fetch the page, wait for the page to render
|
# Create a driver, fetch the page, wait for the page to render
|
||||||
|
|
@ -327,7 +347,7 @@ def _get_slice_visualization(schedule):
|
||||||
return _generate_mail_content(schedule, screenshot, slc.slice_name, slice_url)
|
return _generate_mail_content(schedule, screenshot, slc.slice_name, slice_url)
|
||||||
|
|
||||||
|
|
||||||
def deliver_slice(schedule):
|
def deliver_slice(schedule: Union[DashboardEmailSchedule, SliceEmailSchedule]) -> None:
|
||||||
"""
|
"""
|
||||||
Given a schedule, delivery the slice as an email report
|
Given a schedule, delivery the slice as an email report
|
||||||
"""
|
"""
|
||||||
|
|
@ -352,9 +372,12 @@ def deliver_slice(schedule):
|
||||||
bind=True,
|
bind=True,
|
||||||
soft_time_limit=config["EMAIL_ASYNC_TIME_LIMIT_SEC"],
|
soft_time_limit=config["EMAIL_ASYNC_TIME_LIMIT_SEC"],
|
||||||
)
|
)
|
||||||
def schedule_email_report(
|
def schedule_email_report( # pylint: disable=unused-argument
|
||||||
task, report_type, schedule_id, recipients=None
|
task: Task,
|
||||||
): # pylint: disable=unused-argument
|
report_type: ScheduleType,
|
||||||
|
schedule_id: int,
|
||||||
|
recipients: Optional[str] = None,
|
||||||
|
) -> None:
|
||||||
model_cls = get_scheduler_model(report_type)
|
model_cls = get_scheduler_model(report_type)
|
||||||
schedule = db.create_scoped_session().query(model_cls).get(schedule_id)
|
schedule = db.create_scoped_session().query(model_cls).get(schedule_id)
|
||||||
|
|
||||||
|
|
@ -368,15 +391,17 @@ def schedule_email_report(
|
||||||
schedule.id = schedule_id
|
schedule.id = schedule_id
|
||||||
schedule.recipients = recipients
|
schedule.recipients = recipients
|
||||||
|
|
||||||
if report_type == ScheduleType.dashboard.value:
|
if report_type == ScheduleType.dashboard:
|
||||||
deliver_dashboard(schedule)
|
deliver_dashboard(schedule)
|
||||||
elif report_type == ScheduleType.slice.value:
|
elif report_type == ScheduleType.slice:
|
||||||
deliver_slice(schedule)
|
deliver_slice(schedule)
|
||||||
else:
|
else:
|
||||||
raise RuntimeError("Unknown report type")
|
raise RuntimeError("Unknown report type")
|
||||||
|
|
||||||
|
|
||||||
def next_schedules(crontab, start_at, stop_at, resolution=0):
|
def next_schedules(
|
||||||
|
crontab: str, start_at: datetime, stop_at: datetime, resolution: int = 0
|
||||||
|
) -> Iterator[datetime]:
|
||||||
crons = croniter.croniter(crontab, start_at - timedelta(seconds=1))
|
crons = croniter.croniter(crontab, start_at - timedelta(seconds=1))
|
||||||
previous = start_at - timedelta(days=1)
|
previous = start_at - timedelta(days=1)
|
||||||
|
|
||||||
|
|
@ -396,13 +421,19 @@ def next_schedules(crontab, start_at, stop_at, resolution=0):
|
||||||
previous = eta
|
previous = eta
|
||||||
|
|
||||||
|
|
||||||
def schedule_window(report_type, start_at, stop_at, resolution):
|
def schedule_window(
|
||||||
|
report_type: ScheduleType, start_at: datetime, stop_at: datetime, resolution: int
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Find all active schedules and schedule celery tasks for
|
Find all active schedules and schedule celery tasks for
|
||||||
each of them with a specific ETA (determined by parsing
|
each of them with a specific ETA (determined by parsing
|
||||||
the cron schedule for the schedule)
|
the cron schedule for the schedule)
|
||||||
"""
|
"""
|
||||||
model_cls = get_scheduler_model(report_type)
|
model_cls = get_scheduler_model(report_type)
|
||||||
|
|
||||||
|
if not model_cls:
|
||||||
|
return None
|
||||||
|
|
||||||
dbsession = db.create_scoped_session()
|
dbsession = db.create_scoped_session()
|
||||||
schedules = dbsession.query(model_cls).filter(model_cls.active.is_(True))
|
schedules = dbsession.query(model_cls).filter(model_cls.active.is_(True))
|
||||||
|
|
||||||
|
|
@ -415,9 +446,11 @@ def schedule_window(report_type, start_at, stop_at, resolution):
|
||||||
):
|
):
|
||||||
schedule_email_report.apply_async(args, eta=eta)
|
schedule_email_report.apply_async(args, eta=eta)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
@celery_app.task(name="email_reports.schedule_hourly")
|
@celery_app.task(name="email_reports.schedule_hourly")
|
||||||
def schedule_hourly():
|
def schedule_hourly() -> None:
|
||||||
""" Celery beat job meant to be invoked hourly """
|
""" Celery beat job meant to be invoked hourly """
|
||||||
|
|
||||||
if not config["ENABLE_SCHEDULED_EMAIL_REPORTS"]:
|
if not config["ENABLE_SCHEDULED_EMAIL_REPORTS"]:
|
||||||
|
|
@ -429,5 +462,5 @@ def schedule_hourly():
|
||||||
# Get the top of the hour
|
# Get the top of the hour
|
||||||
start_at = datetime.now(tzlocal()).replace(microsecond=0, second=0, minute=0)
|
start_at = datetime.now(tzlocal()).replace(microsecond=0, second=0, minute=0)
|
||||||
stop_at = start_at + timedelta(seconds=3600)
|
stop_at = start_at + timedelta(seconds=3600)
|
||||||
schedule_window(ScheduleType.dashboard.value, start_at, stop_at, resolution)
|
schedule_window(ScheduleType.dashboard, start_at, stop_at, resolution)
|
||||||
schedule_window(ScheduleType.slice.value, start_at, stop_at, resolution)
|
schedule_window(ScheduleType.slice, start_at, stop_at, resolution)
|
||||||
|
|
|
||||||
|
|
@ -30,8 +30,8 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@celery_app.task(name="cache_chart_thumbnail", soft_time_limit=300)
|
@celery_app.task(name="cache_chart_thumbnail", soft_time_limit=300)
|
||||||
def cache_chart_thumbnail(chart_id: int, force: bool = False):
|
def cache_chart_thumbnail(chart_id: int, force: bool = False) -> None:
|
||||||
with app.app_context():
|
with app.app_context(): # type: ignore
|
||||||
if not thumbnail_cache:
|
if not thumbnail_cache:
|
||||||
logger.warning("No cache set, refusing to compute")
|
logger.warning("No cache set, refusing to compute")
|
||||||
return None
|
return None
|
||||||
|
|
@ -42,8 +42,8 @@ def cache_chart_thumbnail(chart_id: int, force: bool = False):
|
||||||
|
|
||||||
|
|
||||||
@celery_app.task(name="cache_dashboard_thumbnail", soft_time_limit=300)
|
@celery_app.task(name="cache_dashboard_thumbnail", soft_time_limit=300)
|
||||||
def cache_dashboard_thumbnail(dashboard_id: int, force: bool = False):
|
def cache_dashboard_thumbnail(dashboard_id: int, force: bool = False) -> None:
|
||||||
with app.app_context():
|
with app.app_context(): # type: ignore
|
||||||
if not thumbnail_cache:
|
if not thumbnail_cache:
|
||||||
logging.warning("No cache set, refusing to compute")
|
logging.warning("No cache set, refusing to compute")
|
||||||
return None
|
return None
|
||||||
|
|
|
||||||
|
|
@ -235,7 +235,7 @@ def list_minus(l: List, minus: List) -> List:
|
||||||
return [o for o in l if o not in minus]
|
return [o for o in l if o not in minus]
|
||||||
|
|
||||||
|
|
||||||
def parse_human_datetime(s):
|
def parse_human_datetime(s: Optional[str]) -> Optional[datetime]:
|
||||||
"""
|
"""
|
||||||
Returns ``datetime.datetime`` from human readable strings
|
Returns ``datetime.datetime`` from human readable strings
|
||||||
|
|
||||||
|
|
@ -687,42 +687,42 @@ def notify_user_about_perm_udate(granter, user, role, datasource, tpl_name, conf
|
||||||
|
|
||||||
|
|
||||||
def send_email_smtp(
|
def send_email_smtp(
|
||||||
to,
|
to: str,
|
||||||
subject,
|
subject: str,
|
||||||
html_content,
|
html_content: str,
|
||||||
config,
|
config: Dict[str, Any],
|
||||||
files=None,
|
files: Optional[List[str]] = None,
|
||||||
data=None,
|
data: Optional[Dict[str, str]] = None,
|
||||||
images=None,
|
images: Optional[Dict[str, str]] = None,
|
||||||
dryrun=False,
|
dryrun: bool = False,
|
||||||
cc=None,
|
cc: Optional[str] = None,
|
||||||
bcc=None,
|
bcc: Optional[str] = None,
|
||||||
mime_subtype="mixed",
|
mime_subtype: str = "mixed",
|
||||||
):
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Send an email with html content, eg:
|
Send an email with html content, eg:
|
||||||
send_email_smtp(
|
send_email_smtp(
|
||||||
'test@example.com', 'foo', '<b>Foo</b> bar',['/dev/null'], dryrun=True)
|
'test@example.com', 'foo', '<b>Foo</b> bar',['/dev/null'], dryrun=True)
|
||||||
"""
|
"""
|
||||||
smtp_mail_from = config["SMTP_MAIL_FROM"]
|
smtp_mail_from = config["SMTP_MAIL_FROM"]
|
||||||
to = get_email_address_list(to)
|
smtp_mail_to = get_email_address_list(to)
|
||||||
|
|
||||||
msg = MIMEMultipart(mime_subtype)
|
msg = MIMEMultipart(mime_subtype)
|
||||||
msg["Subject"] = subject
|
msg["Subject"] = subject
|
||||||
msg["From"] = smtp_mail_from
|
msg["From"] = smtp_mail_from
|
||||||
msg["To"] = ", ".join(to)
|
msg["To"] = ", ".join(smtp_mail_to)
|
||||||
msg.preamble = "This is a multi-part message in MIME format."
|
msg.preamble = "This is a multi-part message in MIME format."
|
||||||
|
|
||||||
recipients = to
|
recipients = smtp_mail_to
|
||||||
if cc:
|
if cc:
|
||||||
cc = get_email_address_list(cc)
|
smtp_mail_cc = get_email_address_list(cc)
|
||||||
msg["CC"] = ", ".join(cc)
|
msg["CC"] = ", ".join(smtp_mail_cc)
|
||||||
recipients = recipients + cc
|
recipients = recipients + smtp_mail_cc
|
||||||
|
|
||||||
if bcc:
|
if bcc:
|
||||||
# don't add bcc in header
|
# don't add bcc in header
|
||||||
bcc = get_email_address_list(bcc)
|
smtp_mail_bcc = get_email_address_list(bcc)
|
||||||
recipients = recipients + bcc
|
recipients = recipients + smtp_mail_bcc
|
||||||
|
|
||||||
msg["Date"] = formatdate(localtime=True)
|
msg["Date"] = formatdate(localtime=True)
|
||||||
mime_text = MIMEText(html_content, "html")
|
mime_text = MIMEText(html_content, "html")
|
||||||
|
|
@ -1034,8 +1034,8 @@ def get_since_until(
|
||||||
|
|
||||||
"""
|
"""
|
||||||
separator = " : "
|
separator = " : "
|
||||||
relative_start = parse_human_datetime(relative_start if relative_start else "today")
|
relative_start = parse_human_datetime(relative_start if relative_start else "today") # type: ignore
|
||||||
relative_end = parse_human_datetime(relative_end if relative_end else "today")
|
relative_end = parse_human_datetime(relative_end if relative_end else "today") # type: ignore
|
||||||
common_time_frames = {
|
common_time_frames = {
|
||||||
"Last day": (
|
"Last day": (
|
||||||
relative_start - relativedelta(days=1), # type: ignore
|
relative_start - relativedelta(days=1), # type: ignore
|
||||||
|
|
@ -1064,8 +1064,8 @@ def get_since_until(
|
||||||
since, until = time_range.split(separator, 1)
|
since, until = time_range.split(separator, 1)
|
||||||
if since and since not in common_time_frames:
|
if since and since not in common_time_frames:
|
||||||
since = add_ago_to_since(since)
|
since = add_ago_to_since(since)
|
||||||
since = parse_human_datetime(since)
|
since = parse_human_datetime(since) # type: ignore
|
||||||
until = parse_human_datetime(until)
|
until = parse_human_datetime(until) # type: ignore
|
||||||
elif time_range in common_time_frames:
|
elif time_range in common_time_frames:
|
||||||
since, until = common_time_frames[time_range]
|
since, until = common_time_frames[time_range]
|
||||||
elif time_range == "No filter":
|
elif time_range == "No filter":
|
||||||
|
|
@ -1086,8 +1086,8 @@ def get_since_until(
|
||||||
since = since or ""
|
since = since or ""
|
||||||
if since:
|
if since:
|
||||||
since = add_ago_to_since(since)
|
since = add_ago_to_since(since)
|
||||||
since = parse_human_datetime(since)
|
since = parse_human_datetime(since) # type: ignore
|
||||||
until = parse_human_datetime(until) if until else relative_end
|
until = parse_human_datetime(until) if until else relative_end # type: ignore
|
||||||
|
|
||||||
if time_shift:
|
if time_shift:
|
||||||
time_delta = parse_past_timedelta(time_shift)
|
time_delta = parse_past_timedelta(time_shift)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue