[mypy] Enforcing typing for superset.utils (#9905)

Co-authored-by: John Bodley <john.bodley@airbnb.com>
This commit is contained in:
John Bodley 2020-05-27 22:57:30 -07:00 committed by GitHub
parent 54dced1cf6
commit b296a0f250
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 194 additions and 148 deletions

View File

@ -53,7 +53,7 @@ order_by_type = false
ignore_missing_imports = true ignore_missing_imports = true
no_implicit_optional = true no_implicit_optional = true
[mypy-superset.bin.*,superset.charts.*,superset.commands.*,superset.common.*,superset.connectors.*,superset.dao.*,superset.dashboards.*,superset.datasets.*,superset.db_engine_specs.*,superset.db_engines.*,superset.examples.*,superset.migrations.*,superset.models.*,uperset.queries.*,superset.security.*,superset.sql_validators.*,superset.tasks.*,superset.translations.*,superset.views.chart.*,superset.views.dashboard.*,superset.views.database.*] [mypy-superset.bin.*,superset.charts.*,superset.commands.*,superset.common.*,superset.connectors.*,superset.dao.*,superset.dashboards.*,superset.datasets.*,superset.db_engine_specs.*,superset.db_engines.*,superset.examples.*,superset.migrations.*,superset.models.*,uperset.queries.*,superset.security.*,superset.sql_validators.*,superset.tasks.*,superset.translations.*,superset.utils.*,superset.views.chart.*,superset.views.dashboard.*,superset.views.database.*]
check_untyped_defs = true check_untyped_defs = true
disallow_untyped_calls = true disallow_untyped_calls = true
disallow_untyped_defs = true disallow_untyped_defs = true

View File

@ -279,7 +279,7 @@ LANGUAGES = {
# For example, DEFAULT_FEATURE_FLAGS = { 'FOO': True, 'BAR': False } here # For example, DEFAULT_FEATURE_FLAGS = { 'FOO': True, 'BAR': False } here
# and FEATURE_FLAGS = { 'BAR': True, 'BAZ': True } in superset_config.py # and FEATURE_FLAGS = { 'BAR': True, 'BAZ': True } in superset_config.py
# will result in combined feature flags of { 'FOO': True, 'BAR': True, 'BAZ': True } # will result in combined feature flags of { 'FOO': True, 'BAR': True, 'BAZ': True }
DEFAULT_FEATURE_FLAGS = { DEFAULT_FEATURE_FLAGS: Dict[str, bool] = {
# Experimental feature introducing a client (browser) cache # Experimental feature introducing a client (browser) cache
"CLIENT_CACHE": False, "CLIENT_CACHE": False,
"ENABLE_EXPLORE_JSON_CSRF_PROTECTION": False, "ENABLE_EXPLORE_JSON_CSRF_PROTECTION": False,

View File

@ -28,6 +28,7 @@ DbapiDescription = Union[List[DbapiDescriptionRow], Tuple[DbapiDescriptionRow, .
DbapiResult = List[Union[List[Any], Tuple[Any, ...]]] DbapiResult = List[Union[List[Any], Tuple[Any, ...]]]
FilterValue = Union[float, int, str] FilterValue = Union[float, int, str]
FilterValues = Union[FilterValue, List[FilterValue], Tuple[FilterValue]] FilterValues = Union[FilterValue, List[FilterValue], Tuple[FilterValue]]
FormData = Dict[str, Any]
Granularity = Union[str, Dict[str, Union[str, float]]] Granularity = Union[str, Dict[str, Union[str, float]]]
Metric = Union[Dict[str, str], str] Metric = Union[Dict[str, str], str]
QueryObjectDict = Dict[str, Any] QueryObjectDict = Dict[str, Any]

View File

@ -14,14 +14,14 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
from typing import Callable, Optional from typing import Any, Callable, Optional
from flask import request from flask import request
from superset.extensions import cache_manager from superset.extensions import cache_manager
def view_cache_key(*_, **__) -> str: def view_cache_key(*args: Any, **kwargs: Any) -> str: # pylint: disable=unused-argument
args_hash = hash(frozenset(request.args.items())) args_hash = hash(frozenset(request.args.items()))
return "view/{}/{}".format(request.path, args_hash) return "view/{}/{}".format(request.path, args_hash)
@ -45,10 +45,10 @@ def memoized_func(
returns the caching key. returns the caching key.
""" """
def wrap(f): def wrap(f: Callable) -> Callable:
if cache_manager.tables_cache: if cache_manager.tables_cache:
def wrapped_f(self, *args, **kwargs): def wrapped_f(self: Any, *args: Any, **kwargs: Any) -> Any:
if not kwargs.get("cache", True): if not kwargs.get("cache", True):
return f(self, *args, **kwargs) return f(self, *args, **kwargs)
@ -69,7 +69,7 @@ def memoized_func(
else: else:
# noop # noop
def wrapped_f(self, *args, **kwargs): def wrapped_f(self: Any, *args: Any, **kwargs: Any) -> Any:
return f(self, *args, **kwargs) return f(self, *args, **kwargs)
return wrapped_f return wrapped_f

View File

@ -39,6 +39,7 @@ from email.utils import formatdate
from enum import Enum from enum import Enum
from time import struct_time from time import struct_time
from timeit import default_timer from timeit import default_timer
from types import TracebackType
from typing import ( from typing import (
Any, Any,
Callable, Callable,
@ -51,6 +52,7 @@ from typing import (
Sequence, Sequence,
Set, Set,
Tuple, Tuple,
Type,
TYPE_CHECKING, TYPE_CHECKING,
Union, Union,
) )
@ -69,10 +71,12 @@ from dateutil.parser import parse
from dateutil.relativedelta import relativedelta from dateutil.relativedelta import relativedelta
from flask import current_app, flash, g, Markup, render_template from flask import current_app, flash, g, Markup, render_template
from flask_appbuilder import SQLA from flask_appbuilder import SQLA
from flask_appbuilder.security.sqla.models import User from flask_appbuilder.security.sqla.models import Role, User
from flask_babel import gettext as __, lazy_gettext as _ from flask_babel import gettext as __, lazy_gettext as _
from sqlalchemy import event, exc, select, Text from sqlalchemy import event, exc, select, Text
from sqlalchemy.dialects.mysql import MEDIUMTEXT from sqlalchemy.dialects.mysql import MEDIUMTEXT
from sqlalchemy.engine import Connection, Engine
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.sql.type_api import Variant from sqlalchemy.sql.type_api import Variant
from sqlalchemy.types import TEXT, TypeDecorator from sqlalchemy.types import TEXT, TypeDecorator
@ -81,7 +85,7 @@ from superset.exceptions import (
SupersetException, SupersetException,
SupersetTimeoutException, SupersetTimeoutException,
) )
from superset.typing import Metric from superset.typing import FormData, Metric
from superset.utils.dates import datetime_to_epoch, EPOCH from superset.utils.dates import datetime_to_epoch, EPOCH
try: try:
@ -90,6 +94,7 @@ except ImportError:
pass pass
if TYPE_CHECKING: if TYPE_CHECKING:
from superset.connectors.base.models import BaseDatasource
from superset.models.core import Database from superset.models.core import Database
@ -121,7 +126,7 @@ except NameError:
pass pass
def flasher(msg: str, severity: str) -> None: def flasher(msg: str, severity: str = "message") -> None:
"""Flask's flash if available, logging call if not""" """Flask's flash if available, logging call if not"""
try: try:
flash(msg, severity) flash(msg, severity)
@ -142,17 +147,17 @@ class _memoized:
should account for instance variable changes. should account for instance variable changes.
""" """
def __init__(self, func, watch=()): def __init__(self, func: Callable, watch: Optional[List[str]] = None) -> None:
self.func = func self.func = func
self.cache = {} self.cache: Dict[Any, Any] = {}
self.is_method = False self.is_method = False
self.watch = watch or [] self.watch = watch or []
def __call__(self, *args, **kwargs): def __call__(self, *args: Any, **kwargs: Any) -> Any:
key = [args, frozenset(kwargs.items())] key = [args, frozenset(kwargs.items())]
if self.is_method: if self.is_method:
key.append(tuple([getattr(args[0], v, None) for v in self.watch])) key.append(tuple([getattr(args[0], v, None) for v in self.watch]))
key = tuple(key) key = tuple(key) # type: ignore
if key in self.cache: if key in self.cache:
return self.cache[key] return self.cache[key]
try: try:
@ -164,23 +169,25 @@ class _memoized:
# Better to not cache than to blow up entirely. # Better to not cache than to blow up entirely.
return self.func(*args, **kwargs) return self.func(*args, **kwargs)
def __repr__(self): def __repr__(self) -> str:
"""Return the function's docstring.""" """Return the function's docstring."""
return self.func.__doc__ return self.func.__doc__ or ""
def __get__(self, obj, objtype): def __get__(self, obj: Any, objtype: Type) -> functools.partial:
if not self.is_method: if not self.is_method:
self.is_method = True self.is_method = True
"""Support instance methods.""" """Support instance methods."""
return functools.partial(self.__call__, obj) return functools.partial(self.__call__, obj)
def memoized(func: Optional[Callable] = None, watch: Optional[List[str]] = None): def memoized(
func: Optional[Callable] = None, watch: Optional[List[str]] = None
) -> Callable:
if func: if func:
return _memoized(func) return _memoized(func)
else: else:
def wrapper(f): def wrapper(f: Callable) -> Callable:
return _memoized(f, watch) return _memoized(f, watch)
return wrapper return wrapper
@ -229,7 +236,7 @@ def cast_to_num(value: Union[float, int, str]) -> Optional[Union[float, int]]:
return None return None
def list_minus(l: List, minus: List) -> List: def list_minus(l: List[Any], minus: List[Any]) -> List[Any]:
"""Returns l without what is in minus """Returns l without what is in minus
>>> list_minus([1, 2, 3], [2]) >>> list_minus([1, 2, 3], [2])
@ -284,19 +291,19 @@ def md5_hex(data: str) -> str:
class DashboardEncoder(json.JSONEncoder): class DashboardEncoder(json.JSONEncoder):
def __init__(self, *args, **kwargs): def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.sort_keys = True self.sort_keys = True
# pylint: disable=E0202 # pylint: disable=E0202
def default(self, o): def default(self, o: Any) -> Dict[Any, Any]:
try: try:
vals = {k: v for k, v in o.__dict__.items() if k != "_sa_instance_state"} vals = {k: v for k, v in o.__dict__.items() if k != "_sa_instance_state"}
return {"__{}__".format(o.__class__.__name__): vals} return {"__{}__".format(o.__class__.__name__): vals}
except Exception: except Exception:
if type(o) == datetime: if type(o) == datetime:
return {"__datetime__": o.replace(microsecond=0).isoformat()} return {"__datetime__": o.replace(microsecond=0).isoformat()}
return json.JSONEncoder(sort_keys=True).default(self, o) return json.JSONEncoder(sort_keys=True).default(o)
def parse_human_timedelta(s: Optional[str]) -> timedelta: def parse_human_timedelta(s: Optional[str]) -> timedelta:
@ -332,28 +339,15 @@ class JSONEncodedDict(TypeDecorator):
impl = TEXT impl = TEXT
def process_bind_param(self, value, dialect): def process_bind_param(
if value is not None: self, value: Optional[Dict[Any, Any]], dialect: str
value = json.dumps(value) ) -> Optional[str]:
return json.dumps(value) if value is not None else None
return value def process_result_value(
self, value: Optional[str], dialect: str
def process_result_value(self, value, dialect): ) -> Optional[Dict[Any, Any]]:
if value is not None: return json.loads(value) if value is not None else None
value = json.loads(value)
return value
def datetime_f(dttm):
"""Formats datetime to take less room when it is recent"""
if dttm:
dttm = dttm.isoformat()
now_iso = datetime.now().isoformat()
if now_iso[:10] == dttm[:10]:
dttm = dttm[11:]
elif now_iso[:4] == dttm[:4]:
dttm = dttm[5:]
return "<nobr>{}</nobr>".format(dttm)
def format_timedelta(td: timedelta) -> str: def format_timedelta(td: timedelta) -> str:
@ -373,7 +367,7 @@ def format_timedelta(td: timedelta) -> str:
return str(td) return str(td)
def base_json_conv(obj): def base_json_conv(obj: Any) -> Any:
if isinstance(obj, memoryview): if isinstance(obj, memoryview):
obj = obj.tobytes() obj = obj.tobytes()
if isinstance(obj, np.int64): if isinstance(obj, np.int64):
@ -397,7 +391,7 @@ def base_json_conv(obj):
return "[bytes]" return "[bytes]"
def json_iso_dttm_ser(obj, pessimistic: Optional[bool] = False): def json_iso_dttm_ser(obj: Any, pessimistic: bool = False) -> str:
""" """
json serializer that deals with dates json serializer that deals with dates
@ -420,14 +414,14 @@ def json_iso_dttm_ser(obj, pessimistic: Optional[bool] = False):
return obj return obj
def pessimistic_json_iso_dttm_ser(obj): def pessimistic_json_iso_dttm_ser(obj: Any) -> str:
"""Proxy to call json_iso_dttm_ser in a pessimistic way """Proxy to call json_iso_dttm_ser in a pessimistic way
If one of object is not serializable to json, it will still succeed""" If one of object is not serializable to json, it will still succeed"""
return json_iso_dttm_ser(obj, pessimistic=True) return json_iso_dttm_ser(obj, pessimistic=True)
def json_int_dttm_ser(obj): def json_int_dttm_ser(obj: Any) -> float:
"""json serializer that deals with dates""" """json serializer that deals with dates"""
val = base_json_conv(obj) val = base_json_conv(obj)
if val is not None: if val is not None:
@ -441,7 +435,7 @@ def json_int_dttm_ser(obj):
return obj return obj
def json_dumps_w_dates(payload): def json_dumps_w_dates(payload: Dict[Any, Any]) -> str:
return json.dumps(payload, default=json_int_dttm_ser) return json.dumps(payload, default=json_int_dttm_ser)
@ -522,7 +516,7 @@ def readfile(file_path: str) -> Optional[str]:
def generic_find_constraint_name( def generic_find_constraint_name(
table: str, columns: Set[str], referenced: str, db: SQLA table: str, columns: Set[str], referenced: str, db: SQLA
): ) -> Optional[str]:
"""Utility to find a constraint name in alembic migrations""" """Utility to find a constraint name in alembic migrations"""
t = sa.Table(table, db.metadata, autoload=True, autoload_with=db.engine) t = sa.Table(table, db.metadata, autoload=True, autoload_with=db.engine)
@ -530,10 +524,12 @@ def generic_find_constraint_name(
if fk.referred_table.name == referenced and set(fk.column_keys) == columns: if fk.referred_table.name == referenced and set(fk.column_keys) == columns:
return fk.name return fk.name
return None
def generic_find_fk_constraint_name( def generic_find_fk_constraint_name(
table: str, columns: Set[str], referenced: str, insp table: str, columns: Set[str], referenced: str, insp: Inspector
): ) -> Optional[str]:
"""Utility to find a foreign-key constraint name in alembic migrations""" """Utility to find a foreign-key constraint name in alembic migrations"""
for fk in insp.get_foreign_keys(table): for fk in insp.get_foreign_keys(table):
if ( if (
@ -542,8 +538,12 @@ def generic_find_fk_constraint_name(
): ):
return fk["name"] return fk["name"]
return None
def generic_find_fk_constraint_names(table, columns, referenced, insp):
def generic_find_fk_constraint_names(
table: str, columns: Set[str], referenced: str, insp: Inspector
) -> Set[str]:
"""Utility to find foreign-key constraint names in alembic migrations""" """Utility to find foreign-key constraint names in alembic migrations"""
names = set() names = set()
@ -557,13 +557,17 @@ def generic_find_fk_constraint_names(table, columns, referenced, insp):
return names return names
def generic_find_uq_constraint_name(table, columns, insp): def generic_find_uq_constraint_name(
table: str, columns: Set[str], insp: Inspector
) -> Optional[str]:
"""Utility to find a unique constraint name in alembic migrations""" """Utility to find a unique constraint name in alembic migrations"""
for uq in insp.get_unique_constraints(table): for uq in insp.get_unique_constraints(table):
if columns == set(uq["column_names"]): if columns == set(uq["column_names"]):
return uq["name"] return uq["name"]
return None
def get_datasource_full_name( def get_datasource_full_name(
database_name: str, datasource_name: str, schema: Optional[str] = None database_name: str, datasource_name: str, schema: Optional[str] = None
@ -582,30 +586,20 @@ def validate_json(obj: Union[bytes, bytearray, str]) -> None:
raise SupersetException("JSON is not valid") raise SupersetException("JSON is not valid")
def table_has_constraint(table, name, db):
"""Utility to find a constraint name in alembic migrations"""
t = sa.Table(table, db.metadata, autoload=True, autoload_with=db.engine)
for c in t.constraints:
if c.name == name:
return True
return False
class timeout: class timeout:
""" """
To be used in a ``with`` block and timeout its content. To be used in a ``with`` block and timeout its content.
""" """
def __init__(self, seconds=1, error_message="Timeout"): def __init__(self, seconds: int = 1, error_message: str = "Timeout") -> None:
self.seconds = seconds self.seconds = seconds
self.error_message = error_message self.error_message = error_message
def handle_timeout(self, signum, frame): def handle_timeout(self, signum: int, frame: Any) -> None:
logger.error("Process timed out") logger.error("Process timed out")
raise SupersetTimeoutException(self.error_message) raise SupersetTimeoutException(self.error_message)
def __enter__(self): def __enter__(self) -> None:
try: try:
signal.signal(signal.SIGALRM, self.handle_timeout) signal.signal(signal.SIGALRM, self.handle_timeout)
signal.alarm(self.seconds) signal.alarm(self.seconds)
@ -613,7 +607,7 @@ class timeout:
logger.warning("timeout can't be used in the current context") logger.warning("timeout can't be used in the current context")
logger.exception(ex) logger.exception(ex)
def __exit__(self, type, value, traceback): def __exit__(self, type: Any, value: Any, traceback: TracebackType) -> None:
try: try:
signal.alarm(0) signal.alarm(0)
except ValueError as ex: except ValueError as ex:
@ -621,9 +615,9 @@ class timeout:
logger.exception(ex) logger.exception(ex)
def pessimistic_connection_handling(some_engine): def pessimistic_connection_handling(some_engine: Engine) -> None:
@event.listens_for(some_engine, "engine_connect") @event.listens_for(some_engine, "engine_connect")
def ping_connection(connection, branch): def ping_connection(connection: Connection, branch: bool) -> None:
if branch: if branch:
# 'branch' refers to a sub-connection of a connection, # 'branch' refers to a sub-connection of a connection,
# we don't want to bother pinging on these. # we don't want to bother pinging on these.
@ -670,7 +664,14 @@ class QueryStatus:
TIMED_OUT: str = "timed_out" TIMED_OUT: str = "timed_out"
def notify_user_about_perm_udate(granter, user, role, datasource, tpl_name, config): def notify_user_about_perm_udate(
granter: User,
user: User,
role: Role,
datasource: "BaseDatasource",
tpl_name: str,
config: Dict[str, Any],
) -> None:
msg = render_template( msg = render_template(
tpl_name, granter=granter, user=user, role=role, datasource=datasource tpl_name, granter=granter, user=user, role=role, datasource=datasource
) )
@ -762,7 +763,13 @@ def send_email_smtp(
send_MIME_email(smtp_mail_from, recipients, msg, config, dryrun=dryrun) send_MIME_email(smtp_mail_from, recipients, msg, config, dryrun=dryrun)
def send_MIME_email(e_from, e_to, mime_msg, config, dryrun=False): def send_MIME_email(
e_from: str,
e_to: List[str],
mime_msg: MIMEMultipart,
config: Dict[str, Any],
dryrun: bool = False,
) -> None:
SMTP_HOST = config["SMTP_HOST"] SMTP_HOST = config["SMTP_HOST"]
SMTP_PORT = config["SMTP_PORT"] SMTP_PORT = config["SMTP_PORT"]
SMTP_USER = config["SMTP_USER"] SMTP_USER = config["SMTP_USER"]
@ -800,7 +807,7 @@ def choicify(values: Iterable[Any]) -> List[Tuple[Any, Any]]:
return [(v, v) for v in values] return [(v, v) for v in values]
def zlib_compress(data): def zlib_compress(data: Union[bytes, str]) -> bytes:
""" """
Compress things in a py2/3 safe fashion Compress things in a py2/3 safe fashion
>>> json_str = '{"test": 1}' >>> json_str = '{"test": 1}'
@ -827,7 +834,9 @@ def zlib_decompress(blob: bytes, decode: Optional[bool] = True) -> Union[bytes,
return decompressed.decode("utf-8") if decode else decompressed return decompressed.decode("utf-8") if decode else decompressed
def to_adhoc(filt, expressionType="SIMPLE", clause="where"): def to_adhoc(
filt: Dict[str, Any], expressionType: str = "SIMPLE", clause: str = "where"
) -> Dict[str, Any]:
result = { result = {
"clause": clause.upper(), "clause": clause.upper(),
"expressionType": expressionType, "expressionType": expressionType,
@ -849,7 +858,7 @@ def to_adhoc(filt, expressionType="SIMPLE", clause="where"):
return result return result
def merge_extra_filters(form_data: dict): def merge_extra_filters(form_data: Dict[str, Any]) -> None:
# extra_filters are temporary/contextual filters (using the legacy constructs) # extra_filters are temporary/contextual filters (using the legacy constructs)
# that are external to the slice definition. We use those for dynamic # that are external to the slice definition. We use those for dynamic
# interactive filters like the ones emitted by the "Filter Box" visualization. # interactive filters like the ones emitted by the "Filter Box" visualization.
@ -872,7 +881,7 @@ def merge_extra_filters(form_data: dict):
} }
# Grab list of existing filters 'keyed' on the column and operator # Grab list of existing filters 'keyed' on the column and operator
def get_filter_key(f): def get_filter_key(f: Dict[str, Any]) -> str:
if "expressionType" in f: if "expressionType" in f:
return "{}__{}".format(f["subject"], f["operator"]) return "{}__{}".format(f["subject"], f["operator"])
else: else:
@ -945,7 +954,9 @@ def user_label(user: User) -> Optional[str]:
return None return None
def get_or_create_db(database_name, sqlalchemy_uri, *args, **kwargs): def get_or_create_db(
database_name: str, sqlalchemy_uri: str, *args: Any, **kwargs: Any
) -> "Database":
from superset import db from superset import db
from superset.models import core as models from superset.models import core as models
@ -996,7 +1007,7 @@ def get_metric_names(metrics: Sequence[Metric]) -> List[str]:
return [get_metric_name(metric) for metric in metrics] return [get_metric_name(metric) for metric in metrics]
def ensure_path_exists(path: str): def ensure_path_exists(path: str) -> None:
try: try:
os.makedirs(path) os.makedirs(path)
except OSError as exc: except OSError as exc:
@ -1119,7 +1130,7 @@ def add_ago_to_since(since: str) -> str:
return since return since
def convert_legacy_filters_into_adhoc(fd): def convert_legacy_filters_into_adhoc(fd: FormData) -> None:
mapping = {"having": "having_filters", "where": "filters"} mapping = {"having": "having_filters", "where": "filters"}
if not fd.get("adhoc_filters"): if not fd.get("adhoc_filters"):
@ -1138,7 +1149,7 @@ def convert_legacy_filters_into_adhoc(fd):
del fd[key] del fd[key]
def split_adhoc_filters_into_base_filters(fd): def split_adhoc_filters_into_base_filters(fd: FormData) -> None:
""" """
Mutates form data to restructure the adhoc filters in the form of the four base Mutates form data to restructure the adhoc filters in the form of the four base
filters, `where`, `having`, `filters`, and `having_filters` which represent filters, `where`, `having`, `filters`, and `having_filters` which represent
@ -1230,7 +1241,7 @@ def create_ssl_cert_file(certificate: str) -> str:
return path return path
def time_function(func: Callable, *args, **kwargs) -> Tuple[float, Any]: def time_function(func: Callable, *args: Any, **kwargs: Any) -> Tuple[float, Any]:
""" """
Measures the amount of time a function takes to execute in ms Measures the amount of time a function takes to execute in ms
@ -1296,7 +1307,7 @@ def split(
yield s[i:] yield s[i:]
def get_iterable(x: Any) -> List: def get_iterable(x: Any) -> List[Any]:
""" """
Get an iterable (list) representation of the object. Get an iterable (list) representation of the object.

View File

@ -17,14 +17,16 @@
import json import json
import logging import logging
from collections import defaultdict from collections import defaultdict
from typing import Dict, List from typing import Any, Dict, List
from superset.models.slice import Slice from superset.models.slice import Slice
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def convert_filter_scopes(json_metadata: Dict, filters: List[Slice]): def convert_filter_scopes(
json_metadata: Dict[Any, Any], filters: List[Slice]
) -> Dict[int, Dict[str, Dict[str, Any]]]:
filter_scopes = {} filter_scopes = {}
immuned_by_id: List[int] = json_metadata.get("filter_immune_slices") or [] immuned_by_id: List[int] = json_metadata.get("filter_immune_slices") or []
immuned_by_column: Dict = defaultdict(list) immuned_by_column: Dict = defaultdict(list)
@ -34,7 +36,9 @@ def convert_filter_scopes(json_metadata: Dict, filters: List[Slice]):
for column in columns: for column in columns:
immuned_by_column[column].append(int(slice_id)) immuned_by_column[column].append(int(slice_id))
def add_filter_scope(filter_field, filter_id): def add_filter_scope(
filter_fields: Dict[str, Dict[str, Any]], filter_field: str, filter_id: int
) -> None:
# in case filter field is invalid # in case filter field is invalid
if isinstance(filter_field, str): if isinstance(filter_field, str):
current_filter_immune = list( current_filter_immune = list(
@ -54,17 +58,17 @@ def convert_filter_scopes(json_metadata: Dict, filters: List[Slice]):
configs = slice_params.get("filter_configs") or [] configs = slice_params.get("filter_configs") or []
if slice_params.get("date_filter"): if slice_params.get("date_filter"):
add_filter_scope("__time_range", filter_id) add_filter_scope(filter_fields, "__time_range", filter_id)
if slice_params.get("show_sqla_time_column"): if slice_params.get("show_sqla_time_column"):
add_filter_scope("__time_col", filter_id) add_filter_scope(filter_fields, "__time_col", filter_id)
if slice_params.get("show_sqla_time_granularity"): if slice_params.get("show_sqla_time_granularity"):
add_filter_scope("__time_grain", filter_id) add_filter_scope(filter_fields, "__time_grain", filter_id)
if slice_params.get("show_druid_time_granularity"): if slice_params.get("show_druid_time_granularity"):
add_filter_scope("__granularity", filter_id) add_filter_scope(filter_fields, "__granularity", filter_id)
if slice_params.get("show_druid_time_origin"): if slice_params.get("show_druid_time_origin"):
add_filter_scope("druid_time_origin", filter_id) add_filter_scope(filter_fields, "druid_time_origin", filter_id)
for config in configs: for config in configs:
add_filter_scope(config.get("column"), filter_id) add_filter_scope(filter_fields, config.get("column"), filter_id)
if filter_fields: if filter_fields:
filter_scopes[filter_id] = filter_fields filter_scopes[filter_id] = filter_fields

View File

@ -19,6 +19,10 @@ import json
import logging import logging
import time import time
from datetime import datetime from datetime import datetime
from io import BytesIO
from typing import Any, Dict, Optional
from sqlalchemy.orm import Session
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
from superset.models.dashboard import Dashboard from superset.models.dashboard import Dashboard
@ -27,7 +31,7 @@ from superset.models.slice import Slice
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def decode_dashboards(o): def decode_dashboards(o: Dict[str, Any]) -> Any:
""" """
Function to be passed into json.loads obj_hook parameter Function to be passed into json.loads obj_hook parameter
Recreates the dashboard object from a json representation. Recreates the dashboard object from a json representation.
@ -50,7 +54,9 @@ def decode_dashboards(o):
return o return o
def import_dashboards(session, data_stream, import_time=None): def import_dashboards(
session: Session, data_stream: BytesIO, import_time: Optional[int] = None
) -> None:
"""Imports dashboards from a stream to databases""" """Imports dashboards from a stream to databases"""
current_tt = int(time.time()) current_tt = int(time.time())
import_time = current_tt if import_time is None else import_time import_time = current_tt if import_time is None else import_time
@ -64,7 +70,7 @@ def import_dashboards(session, data_stream, import_time=None):
session.commit() session.commit()
def export_dashboards(session): def export_dashboards(session: Session) -> str:
"""Returns all dashboards metadata as a json dump""" """Returns all dashboards metadata as a json dump"""
logger.info("Starting export") logger.info("Starting export")
dashboards = session.query(Dashboard) dashboards = session.query(Dashboard)

View File

@ -21,7 +21,7 @@ import pytz
EPOCH = datetime(1970, 1, 1) EPOCH = datetime(1970, 1, 1)
def datetime_to_epoch(dttm): def datetime_to_epoch(dttm: datetime) -> float:
if dttm.tzinfo: if dttm.tzinfo:
dttm = dttm.replace(tzinfo=pytz.utc) dttm = dttm.replace(tzinfo=pytz.utc)
epoch_with_tz = pytz.utc.localize(EPOCH) epoch_with_tz = pytz.utc.localize(EPOCH)
@ -29,5 +29,5 @@ def datetime_to_epoch(dttm):
return (dttm - EPOCH).total_seconds() * 1000 return (dttm - EPOCH).total_seconds() * 1000
def now_as_float(): def now_as_float() -> float:
return datetime_to_epoch(datetime.utcnow()) return datetime_to_epoch(datetime.utcnow())

View File

@ -17,11 +17,14 @@
import logging import logging
from datetime import datetime, timedelta from datetime import datetime, timedelta
from functools import wraps from functools import wraps
from typing import Any, Callable, Iterator
from contextlib2 import contextmanager from contextlib2 import contextmanager
from flask import request from flask import request
from werkzeug.wrappers.etag import ETagResponseMixin
from superset import app, cache from superset import app, cache
from superset.stats_logger import BaseStatsLogger
from superset.utils.dates import now_as_float from superset.utils.dates import now_as_float
# If a user sets `max_age` to 0, for long the browser should cache the # If a user sets `max_age` to 0, for long the browser should cache the
@ -32,7 +35,7 @@ logger = logging.getLogger(__name__)
@contextmanager @contextmanager
def stats_timing(stats_key, stats_logger): def stats_timing(stats_key: str, stats_logger: BaseStatsLogger) -> Iterator[float]:
"""Provide a transactional scope around a series of operations.""" """Provide a transactional scope around a series of operations."""
start_ts = now_as_float() start_ts = now_as_float()
try: try:
@ -43,7 +46,7 @@ def stats_timing(stats_key, stats_logger):
stats_logger.timing(stats_key, now_as_float() - start_ts) stats_logger.timing(stats_key, now_as_float() - start_ts)
def etag_cache(max_age, check_perms=bool): def etag_cache(max_age: int, check_perms: Callable) -> Callable:
""" """
A decorator for caching views and handling etag conditional requests. A decorator for caching views and handling etag conditional requests.
@ -57,9 +60,9 @@ def etag_cache(max_age, check_perms=bool):
""" """
def decorator(f): def decorator(f: Callable) -> Callable:
@wraps(f) @wraps(f)
def wrapper(*args, **kwargs): def wrapper(*args: Any, **kwargs: Any) -> ETagResponseMixin:
# check if the user can access the resource # check if the user can access the resource
check_perms(*args, **kwargs) check_perms(*args, **kwargs)
@ -77,7 +80,9 @@ def etag_cache(max_age, check_perms=bool):
key_args = list(args) key_args = list(args)
key_kwargs = kwargs.copy() key_kwargs = kwargs.copy()
key_kwargs.update(request.args) key_kwargs.update(request.args)
cache_key = wrapper.make_cache_key(f, *key_args, **key_kwargs) cache_key = wrapper.make_cache_key( # type: ignore
f, *key_args, **key_kwargs
)
response = cache.get(cache_key) response = cache.get(cache_key)
except Exception: # pylint: disable=broad-except except Exception: # pylint: disable=broad-except
if app.debug: if app.debug:
@ -109,9 +114,9 @@ def etag_cache(max_age, check_perms=bool):
return response.make_conditional(request) return response.make_conditional(request)
if cache: if cache:
wrapper.uncached = f wrapper.uncached = f # type: ignore
wrapper.cache_timeout = max_age wrapper.cache_timeout = max_age # type: ignore
wrapper.make_cache_key = cache._memoize_make_cache_key( # pylint: disable=protected-access wrapper.make_cache_key = cache._memoize_make_cache_key( # type: ignore # pylint: disable=protected-access
make_name=None, timeout=max_age make_name=None, timeout=max_age
) )

View File

@ -16,6 +16,9 @@
# under the License. # under the License.
# pylint: disable=C,R,W # pylint: disable=C,R,W
import logging import logging
from typing import Any, Dict, List, Optional
from sqlalchemy.orm import Session
from superset.connectors.druid.models import DruidCluster from superset.connectors.druid.models import DruidCluster
from superset.models.core import Database from superset.models.core import Database
@ -25,7 +28,7 @@ DRUID_CLUSTERS_KEY = "druid_clusters"
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def export_schema_to_dict(back_references): def export_schema_to_dict(back_references: bool) -> Dict[str, Any]:
"""Exports the supported import/export schema to a dictionary""" """Exports the supported import/export schema to a dictionary"""
databases = [ databases = [
Database.export_schema(recursive=True, include_parent_ref=back_references) Database.export_schema(recursive=True, include_parent_ref=back_references)
@ -41,7 +44,9 @@ def export_schema_to_dict(back_references):
return data return data
def export_to_dict(session, recursive, back_references, include_defaults): def export_to_dict(
session: Session, recursive: bool, back_references: bool, include_defaults: bool
) -> Dict[str, Any]:
"""Exports databases and druid clusters to a dictionary""" """Exports databases and druid clusters to a dictionary"""
logger.info("Starting export") logger.info("Starting export")
dbs = session.query(Database) dbs = session.query(Database)
@ -72,8 +77,12 @@ def export_to_dict(session, recursive, back_references, include_defaults):
return data return data
def import_from_dict(session, data, sync=[]): def import_from_dict(
session: Session, data: Dict[str, Any], sync: Optional[List[str]] = None
) -> None:
"""Imports databases and druid clusters from dictionary""" """Imports databases and druid clusters from dictionary"""
if not sync:
sync = []
if isinstance(data, dict): if isinstance(data, dict):
logger.info("Importing %d %s", len(data.get(DATABASES_KEY, [])), DATABASES_KEY) logger.info("Importing %d %s", len(data.get(DATABASES_KEY, [])), DATABASES_KEY)
for database in data.get(DATABASES_KEY, []): for database in data.get(DATABASES_KEY, []):

View File

@ -15,25 +15,33 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
from copy import deepcopy from copy import deepcopy
from typing import Any, Dict
from flask import Flask
class FeatureFlagManager: class FeatureFlagManager:
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
self._get_feature_flags_func = None self._get_feature_flags_func = None
self._feature_flags = None self._feature_flags: Dict[str, Any] = {}
def init_app(self, app): def init_app(self, app: Flask) -> None:
self._get_feature_flags_func = app.config.get("GET_FEATURE_FLAGS_FUNC") self._get_feature_flags_func = app.config["GET_FEATURE_FLAGS_FUNC"]
self._feature_flags = app.config.get("DEFAULT_FEATURE_FLAGS") or {} self._feature_flags = app.config["DEFAULT_FEATURE_FLAGS"]
self._feature_flags.update(app.config.get("FEATURE_FLAGS") or {}) self._feature_flags.update(app.config["FEATURE_FLAGS"])
def get_feature_flags(self): def get_feature_flags(self) -> Dict[str, Any]:
if self._get_feature_flags_func: if self._get_feature_flags_func:
return self._get_feature_flags_func(deepcopy(self._feature_flags)) return self._get_feature_flags_func(deepcopy(self._feature_flags))
return self._feature_flags return self._feature_flags
def is_feature_enabled(self, feature) -> bool: def is_feature_enabled(self, feature: str) -> bool:
"""Utility function for checking whether a feature is turned on""" """Utility function for checking whether a feature is turned on"""
return self.get_feature_flags().get(feature) feature_flags = self.get_feature_flags()
if feature_flags and feature in feature_flags:
return feature_flags[feature]
return False

View File

@ -21,19 +21,23 @@ import logging
import textwrap import textwrap
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from datetime import datetime from datetime import datetime
from typing import Any, cast, Type from typing import Any, Callable, cast, Optional, Type
from flask import current_app, g, request from flask import current_app, g, request
from superset.stats_logger import BaseStatsLogger
class AbstractEventLogger(ABC): class AbstractEventLogger(ABC):
@abstractmethod @abstractmethod
def log(self, user_id, action, *args, **kwargs): def log(
self, user_id: Optional[int], action: str, *args: Any, **kwargs: Any
) -> None:
pass pass
def log_this(self, f): def log_this(self, f: Callable) -> Callable:
@functools.wraps(f) @functools.wraps(f)
def wrapper(*args, **kwargs): def wrapper(*args: Any, **kwargs: Any) -> Any:
user_id = None user_id = None
if g.user: if g.user:
user_id = g.user.get_id() user_id = g.user.get_id()
@ -49,7 +53,12 @@ class AbstractEventLogger(ABC):
try: try:
slice_id = int( slice_id = int(
slice_id or json.loads(form_data.get("form_data")).get("slice_id") slice_id
or json.loads(
form_data.get("form_data") # type: ignore
).get(
"slice_id"
)
) )
except (ValueError, TypeError): except (ValueError, TypeError):
slice_id = 0 slice_id = 0
@ -62,7 +71,7 @@ class AbstractEventLogger(ABC):
# bulk insert # bulk insert
try: try:
explode_by = form_data.get("explode") explode_by = form_data.get("explode")
records = json.loads(form_data.get(explode_by)) records = json.loads(form_data.get(explode_by)) # type: ignore
except Exception: # pylint: disable=broad-except except Exception: # pylint: disable=broad-except
records = [form_data] records = [form_data]
@ -82,11 +91,11 @@ class AbstractEventLogger(ABC):
return wrapper return wrapper
@property @property
def stats_logger(self): def stats_logger(self) -> BaseStatsLogger:
return current_app.config["STATS_LOGGER"] return current_app.config["STATS_LOGGER"]
def get_event_logger_from_cfg_value(cfg_value: object) -> AbstractEventLogger: def get_event_logger_from_cfg_value(cfg_value: Any) -> AbstractEventLogger:
""" """
This function implements the deprecation of assignment This function implements the deprecation of assignment
of class objects to EVENT_LOGGER configuration, and validates of class objects to EVENT_LOGGER configuration, and validates
@ -130,7 +139,9 @@ def get_event_logger_from_cfg_value(cfg_value: object) -> AbstractEventLogger:
class DBEventLogger(AbstractEventLogger): class DBEventLogger(AbstractEventLogger):
def log(self, user_id, action, *args, **kwargs): # pylint: disable=too-many-locals def log( # pylint: disable=too-many-locals
self, user_id: Optional[int], action: str, *args: Any, **kwargs: Any
) -> None:
from superset.models.core import Log from superset.models.core import Log
records = kwargs.get("records", list()) records = kwargs.get("records", list())
@ -141,6 +152,7 @@ class DBEventLogger(AbstractEventLogger):
logs = list() logs = list()
for record in records: for record in records:
json_string: Optional[str]
try: try:
json_string = json.dumps(record) json_string = json.dumps(record)
except Exception: # pylint: disable=broad-except except Exception: # pylint: disable=broad-except

View File

@ -73,8 +73,8 @@ WHITELIST_CUMULATIVE_FUNCTIONS = (
def validate_column_args(*argnames: str) -> Callable: def validate_column_args(*argnames: str) -> Callable:
def wrapper(func): def wrapper(func: Callable) -> Callable:
def wrapped(df, **options): def wrapped(df: DataFrame, **options: Any) -> Any:
columns = df.columns.tolist() columns = df.columns.tolist()
for name in argnames: for name in argnames:
if name in options and not all( if name in options and not all(
@ -159,7 +159,7 @@ def pivot( # pylint: disable=too-many-arguments
metric_fill_value: Optional[Any] = None, metric_fill_value: Optional[Any] = None,
column_fill_value: Optional[str] = None, column_fill_value: Optional[str] = None,
drop_missing_columns: Optional[bool] = True, drop_missing_columns: Optional[bool] = True,
combine_value_with_metric=False, combine_value_with_metric: bool = False,
marginal_distributions: Optional[bool] = None, marginal_distributions: Optional[bool] = None,
marginal_distribution_name: Optional[str] = None, marginal_distribution_name: Optional[str] = None,
) -> DataFrame: ) -> DataFrame:

View File

@ -18,7 +18,7 @@ import logging
import time import time
import urllib.parse import urllib.parse
from io import BytesIO from io import BytesIO
from typing import Callable, Dict, List, Optional, Tuple, TYPE_CHECKING from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING
from flask import current_app, request, Response, session, url_for from flask import current_app, request, Response, session, url_for
from flask_login import login_user from flask_login import login_user
@ -91,7 +91,7 @@ def headless_url(path: str) -> str:
return urllib.parse.urljoin(current_app.config.get("WEBDRIVER_BASEURL", ""), path) return urllib.parse.urljoin(current_app.config.get("WEBDRIVER_BASEURL", ""), path)
def get_url_path(view: str, **kwargs) -> str: def get_url_path(view: str, **kwargs: Any) -> str:
with current_app.test_request_context(): with current_app.test_request_context():
return headless_url(url_for(view, **kwargs)) return headless_url(url_for(view, **kwargs))
@ -135,7 +135,7 @@ class AuthWebDriverProxy:
return self._auth_func(driver, user) return self._auth_func(driver, user)
@staticmethod @staticmethod
def destroy(driver: WebDriver, tries=2): def destroy(driver: WebDriver, tries: int = 2) -> None:
"""Destroy a driver""" """Destroy a driver"""
# This is some very flaky code in selenium. Hence the retries # This is some very flaky code in selenium. Hence the retries
# and catch-all exceptions # and catch-all exceptions

View File

@ -14,22 +14,24 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
from werkzeug.routing import BaseConverter from typing import Any, List
from werkzeug.routing import BaseConverter, Map
from superset.models.tags import ObjectTypes from superset.models.tags import ObjectTypes
class RegexConverter(BaseConverter): class RegexConverter(BaseConverter):
def __init__(self, url_map, *items): def __init__(self, url_map: Map, *items: List[str]) -> None:
super(RegexConverter, self).__init__(url_map) super(RegexConverter, self).__init__(url_map) # type: ignore
self.regex = items[0] self.regex = items[0]
class ObjectTypeConverter(BaseConverter): class ObjectTypeConverter(BaseConverter):
"""Validate that object_type is indeed an object type.""" """Validate that object_type is indeed an object type."""
def to_python(self, value): def to_python(self, value: str) -> Any:
return ObjectTypes[value] return ObjectTypes[value]
def to_url(self, value): def to_url(self, value: Any) -> str:
return value.name return value.name

View File

@ -2164,7 +2164,7 @@ class Superset(BaseSupersetView):
return json_error_response(str(ex)) return json_error_response(str(ex))
spec = mydb.db_engine_spec spec = mydb.db_engine_spec
query_cost_formatters = get_feature_flags().get( query_cost_formatters: Dict[str, Any] = get_feature_flags().get(
"QUERY_COST_FORMATTERS_BY_ENGINE", {} "QUERY_COST_FORMATTERS_BY_ENGINE", {}
) )
query_cost_formatter = query_cost_formatters.get( query_cost_formatter = query_cost_formatters.get(

View File

@ -38,7 +38,6 @@ from superset.utils.core import (
base_json_conv, base_json_conv,
convert_legacy_filters_into_adhoc, convert_legacy_filters_into_adhoc,
create_ssl_cert_file, create_ssl_cert_file,
datetime_f,
format_timedelta, format_timedelta,
get_iterable, get_iterable,
get_email_address_list, get_email_address_list,
@ -560,17 +559,6 @@ class UtilsTestCase(SupersetTestCase):
url_params["dashboard_ids"], form_data["url_params"]["dashboard_ids"] url_params["dashboard_ids"], form_data["url_params"]["dashboard_ids"]
) )
def test_datetime_f(self):
self.assertEqual(
datetime_f(datetime(1990, 9, 21, 19, 11, 19, 626096)),
"<nobr>1990-09-21T19:11:19.626096</nobr>",
)
self.assertEqual(len(datetime_f(datetime.now())), 28)
self.assertEqual(datetime_f(None), "<nobr>None</nobr>")
iso = datetime.now().isoformat()[:10].split("-")
[a, b, c] = [int(v) for v in iso]
self.assertEqual(datetime_f(datetime(a, b, c)), "<nobr>00:00:00</nobr>")
def test_format_timedelta(self): def test_format_timedelta(self):
self.assertEqual(format_timedelta(timedelta(0)), "0:00:00") self.assertEqual(format_timedelta(timedelta(0)), "0:00:00")
self.assertEqual(format_timedelta(timedelta(days=1)), "1 day, 0:00:00") self.assertEqual(format_timedelta(timedelta(days=1)), "1 day, 0:00:00")