[mypy] Enforcing typing for superset.utils (#9905)

Co-authored-by: John Bodley <john.bodley@airbnb.com>
This commit is contained in:
John Bodley 2020-05-27 22:57:30 -07:00 committed by GitHub
parent 54dced1cf6
commit b296a0f250
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 194 additions and 148 deletions

View File

@ -53,7 +53,7 @@ order_by_type = false
ignore_missing_imports = true
no_implicit_optional = true
[mypy-superset.bin.*,superset.charts.*,superset.commands.*,superset.common.*,superset.connectors.*,superset.dao.*,superset.dashboards.*,superset.datasets.*,superset.db_engine_specs.*,superset.db_engines.*,superset.examples.*,superset.migrations.*,superset.models.*,uperset.queries.*,superset.security.*,superset.sql_validators.*,superset.tasks.*,superset.translations.*,superset.views.chart.*,superset.views.dashboard.*,superset.views.database.*]
[mypy-superset.bin.*,superset.charts.*,superset.commands.*,superset.common.*,superset.connectors.*,superset.dao.*,superset.dashboards.*,superset.datasets.*,superset.db_engine_specs.*,superset.db_engines.*,superset.examples.*,superset.migrations.*,superset.models.*,uperset.queries.*,superset.security.*,superset.sql_validators.*,superset.tasks.*,superset.translations.*,superset.utils.*,superset.views.chart.*,superset.views.dashboard.*,superset.views.database.*]
check_untyped_defs = true
disallow_untyped_calls = true
disallow_untyped_defs = true

View File

@ -279,7 +279,7 @@ LANGUAGES = {
# For example, DEFAULT_FEATURE_FLAGS = { 'FOO': True, 'BAR': False } here
# and FEATURE_FLAGS = { 'BAR': True, 'BAZ': True } in superset_config.py
# will result in combined feature flags of { 'FOO': True, 'BAR': True, 'BAZ': True }
DEFAULT_FEATURE_FLAGS = {
DEFAULT_FEATURE_FLAGS: Dict[str, bool] = {
# Experimental feature introducing a client (browser) cache
"CLIENT_CACHE": False,
"ENABLE_EXPLORE_JSON_CSRF_PROTECTION": False,

View File

@ -28,6 +28,7 @@ DbapiDescription = Union[List[DbapiDescriptionRow], Tuple[DbapiDescriptionRow, .
DbapiResult = List[Union[List[Any], Tuple[Any, ...]]]
FilterValue = Union[float, int, str]
FilterValues = Union[FilterValue, List[FilterValue], Tuple[FilterValue]]
FormData = Dict[str, Any]
Granularity = Union[str, Dict[str, Union[str, float]]]
Metric = Union[Dict[str, str], str]
QueryObjectDict = Dict[str, Any]

View File

@ -14,14 +14,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Callable, Optional
from typing import Any, Callable, Optional
from flask import request
from superset.extensions import cache_manager
def view_cache_key(*_, **__) -> str:
def view_cache_key(*args: Any, **kwargs: Any) -> str: # pylint: disable=unused-argument
args_hash = hash(frozenset(request.args.items()))
return "view/{}/{}".format(request.path, args_hash)
@ -45,10 +45,10 @@ def memoized_func(
returns the caching key.
"""
def wrap(f):
def wrap(f: Callable) -> Callable:
if cache_manager.tables_cache:
def wrapped_f(self, *args, **kwargs):
def wrapped_f(self: Any, *args: Any, **kwargs: Any) -> Any:
if not kwargs.get("cache", True):
return f(self, *args, **kwargs)
@ -69,7 +69,7 @@ def memoized_func(
else:
# noop
def wrapped_f(self, *args, **kwargs):
def wrapped_f(self: Any, *args: Any, **kwargs: Any) -> Any:
return f(self, *args, **kwargs)
return wrapped_f

View File

@ -39,6 +39,7 @@ from email.utils import formatdate
from enum import Enum
from time import struct_time
from timeit import default_timer
from types import TracebackType
from typing import (
Any,
Callable,
@ -51,6 +52,7 @@ from typing import (
Sequence,
Set,
Tuple,
Type,
TYPE_CHECKING,
Union,
)
@ -69,10 +71,12 @@ from dateutil.parser import parse
from dateutil.relativedelta import relativedelta
from flask import current_app, flash, g, Markup, render_template
from flask_appbuilder import SQLA
from flask_appbuilder.security.sqla.models import User
from flask_appbuilder.security.sqla.models import Role, User
from flask_babel import gettext as __, lazy_gettext as _
from sqlalchemy import event, exc, select, Text
from sqlalchemy.dialects.mysql import MEDIUMTEXT
from sqlalchemy.engine import Connection, Engine
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.sql.type_api import Variant
from sqlalchemy.types import TEXT, TypeDecorator
@ -81,7 +85,7 @@ from superset.exceptions import (
SupersetException,
SupersetTimeoutException,
)
from superset.typing import Metric
from superset.typing import FormData, Metric
from superset.utils.dates import datetime_to_epoch, EPOCH
try:
@ -90,6 +94,7 @@ except ImportError:
pass
if TYPE_CHECKING:
from superset.connectors.base.models import BaseDatasource
from superset.models.core import Database
@ -121,7 +126,7 @@ except NameError:
pass
def flasher(msg: str, severity: str) -> None:
def flasher(msg: str, severity: str = "message") -> None:
"""Flask's flash if available, logging call if not"""
try:
flash(msg, severity)
@ -142,17 +147,17 @@ class _memoized:
should account for instance variable changes.
"""
def __init__(self, func, watch=()):
def __init__(self, func: Callable, watch: Optional[List[str]] = None) -> None:
self.func = func
self.cache = {}
self.cache: Dict[Any, Any] = {}
self.is_method = False
self.watch = watch or []
def __call__(self, *args, **kwargs):
def __call__(self, *args: Any, **kwargs: Any) -> Any:
key = [args, frozenset(kwargs.items())]
if self.is_method:
key.append(tuple([getattr(args[0], v, None) for v in self.watch]))
key = tuple(key)
key = tuple(key) # type: ignore
if key in self.cache:
return self.cache[key]
try:
@ -164,23 +169,25 @@ class _memoized:
# Better to not cache than to blow up entirely.
return self.func(*args, **kwargs)
def __repr__(self):
def __repr__(self) -> str:
"""Return the function's docstring."""
return self.func.__doc__
return self.func.__doc__ or ""
def __get__(self, obj, objtype):
def __get__(self, obj: Any, objtype: Type) -> functools.partial:
if not self.is_method:
self.is_method = True
"""Support instance methods."""
return functools.partial(self.__call__, obj)
def memoized(func: Optional[Callable] = None, watch: Optional[List[str]] = None):
def memoized(
func: Optional[Callable] = None, watch: Optional[List[str]] = None
) -> Callable:
if func:
return _memoized(func)
else:
def wrapper(f):
def wrapper(f: Callable) -> Callable:
return _memoized(f, watch)
return wrapper
@ -229,7 +236,7 @@ def cast_to_num(value: Union[float, int, str]) -> Optional[Union[float, int]]:
return None
def list_minus(l: List, minus: List) -> List:
def list_minus(l: List[Any], minus: List[Any]) -> List[Any]:
"""Returns l without what is in minus
>>> list_minus([1, 2, 3], [2])
@ -284,19 +291,19 @@ def md5_hex(data: str) -> str:
class DashboardEncoder(json.JSONEncoder):
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.sort_keys = True
# pylint: disable=E0202
def default(self, o):
def default(self, o: Any) -> Dict[Any, Any]:
try:
vals = {k: v for k, v in o.__dict__.items() if k != "_sa_instance_state"}
return {"__{}__".format(o.__class__.__name__): vals}
except Exception:
if type(o) == datetime:
return {"__datetime__": o.replace(microsecond=0).isoformat()}
return json.JSONEncoder(sort_keys=True).default(self, o)
return json.JSONEncoder(sort_keys=True).default(o)
def parse_human_timedelta(s: Optional[str]) -> timedelta:
@ -332,28 +339,15 @@ class JSONEncodedDict(TypeDecorator):
impl = TEXT
def process_bind_param(self, value, dialect):
if value is not None:
value = json.dumps(value)
def process_bind_param(
self, value: Optional[Dict[Any, Any]], dialect: str
) -> Optional[str]:
return json.dumps(value) if value is not None else None
return value
def process_result_value(self, value, dialect):
if value is not None:
value = json.loads(value)
return value
def datetime_f(dttm):
"""Formats datetime to take less room when it is recent"""
if dttm:
dttm = dttm.isoformat()
now_iso = datetime.now().isoformat()
if now_iso[:10] == dttm[:10]:
dttm = dttm[11:]
elif now_iso[:4] == dttm[:4]:
dttm = dttm[5:]
return "<nobr>{}</nobr>".format(dttm)
def process_result_value(
self, value: Optional[str], dialect: str
) -> Optional[Dict[Any, Any]]:
return json.loads(value) if value is not None else None
def format_timedelta(td: timedelta) -> str:
@ -373,7 +367,7 @@ def format_timedelta(td: timedelta) -> str:
return str(td)
def base_json_conv(obj):
def base_json_conv(obj: Any) -> Any:
if isinstance(obj, memoryview):
obj = obj.tobytes()
if isinstance(obj, np.int64):
@ -397,7 +391,7 @@ def base_json_conv(obj):
return "[bytes]"
def json_iso_dttm_ser(obj, pessimistic: Optional[bool] = False):
def json_iso_dttm_ser(obj: Any, pessimistic: bool = False) -> str:
"""
json serializer that deals with dates
@ -420,14 +414,14 @@ def json_iso_dttm_ser(obj, pessimistic: Optional[bool] = False):
return obj
def pessimistic_json_iso_dttm_ser(obj):
def pessimistic_json_iso_dttm_ser(obj: Any) -> str:
"""Proxy to call json_iso_dttm_ser in a pessimistic way
If one of object is not serializable to json, it will still succeed"""
return json_iso_dttm_ser(obj, pessimistic=True)
def json_int_dttm_ser(obj):
def json_int_dttm_ser(obj: Any) -> float:
"""json serializer that deals with dates"""
val = base_json_conv(obj)
if val is not None:
@ -441,7 +435,7 @@ def json_int_dttm_ser(obj):
return obj
def json_dumps_w_dates(payload):
def json_dumps_w_dates(payload: Dict[Any, Any]) -> str:
return json.dumps(payload, default=json_int_dttm_ser)
@ -522,7 +516,7 @@ def readfile(file_path: str) -> Optional[str]:
def generic_find_constraint_name(
table: str, columns: Set[str], referenced: str, db: SQLA
):
) -> Optional[str]:
"""Utility to find a constraint name in alembic migrations"""
t = sa.Table(table, db.metadata, autoload=True, autoload_with=db.engine)
@ -530,10 +524,12 @@ def generic_find_constraint_name(
if fk.referred_table.name == referenced and set(fk.column_keys) == columns:
return fk.name
return None
def generic_find_fk_constraint_name(
table: str, columns: Set[str], referenced: str, insp
):
table: str, columns: Set[str], referenced: str, insp: Inspector
) -> Optional[str]:
"""Utility to find a foreign-key constraint name in alembic migrations"""
for fk in insp.get_foreign_keys(table):
if (
@ -542,8 +538,12 @@ def generic_find_fk_constraint_name(
):
return fk["name"]
return None
def generic_find_fk_constraint_names(table, columns, referenced, insp):
def generic_find_fk_constraint_names(
table: str, columns: Set[str], referenced: str, insp: Inspector
) -> Set[str]:
"""Utility to find foreign-key constraint names in alembic migrations"""
names = set()
@ -557,13 +557,17 @@ def generic_find_fk_constraint_names(table, columns, referenced, insp):
return names
def generic_find_uq_constraint_name(table, columns, insp):
def generic_find_uq_constraint_name(
table: str, columns: Set[str], insp: Inspector
) -> Optional[str]:
"""Utility to find a unique constraint name in alembic migrations"""
for uq in insp.get_unique_constraints(table):
if columns == set(uq["column_names"]):
return uq["name"]
return None
def get_datasource_full_name(
database_name: str, datasource_name: str, schema: Optional[str] = None
@ -582,30 +586,20 @@ def validate_json(obj: Union[bytes, bytearray, str]) -> None:
raise SupersetException("JSON is not valid")
def table_has_constraint(table, name, db):
"""Utility to find a constraint name in alembic migrations"""
t = sa.Table(table, db.metadata, autoload=True, autoload_with=db.engine)
for c in t.constraints:
if c.name == name:
return True
return False
class timeout:
"""
To be used in a ``with`` block and timeout its content.
"""
def __init__(self, seconds=1, error_message="Timeout"):
def __init__(self, seconds: int = 1, error_message: str = "Timeout") -> None:
self.seconds = seconds
self.error_message = error_message
def handle_timeout(self, signum, frame):
def handle_timeout(self, signum: int, frame: Any) -> None:
logger.error("Process timed out")
raise SupersetTimeoutException(self.error_message)
def __enter__(self):
def __enter__(self) -> None:
try:
signal.signal(signal.SIGALRM, self.handle_timeout)
signal.alarm(self.seconds)
@ -613,7 +607,7 @@ class timeout:
logger.warning("timeout can't be used in the current context")
logger.exception(ex)
def __exit__(self, type, value, traceback):
def __exit__(self, type: Any, value: Any, traceback: TracebackType) -> None:
try:
signal.alarm(0)
except ValueError as ex:
@ -621,9 +615,9 @@ class timeout:
logger.exception(ex)
def pessimistic_connection_handling(some_engine):
def pessimistic_connection_handling(some_engine: Engine) -> None:
@event.listens_for(some_engine, "engine_connect")
def ping_connection(connection, branch):
def ping_connection(connection: Connection, branch: bool) -> None:
if branch:
# 'branch' refers to a sub-connection of a connection,
# we don't want to bother pinging on these.
@ -670,7 +664,14 @@ class QueryStatus:
TIMED_OUT: str = "timed_out"
def notify_user_about_perm_udate(granter, user, role, datasource, tpl_name, config):
def notify_user_about_perm_udate(
granter: User,
user: User,
role: Role,
datasource: "BaseDatasource",
tpl_name: str,
config: Dict[str, Any],
) -> None:
msg = render_template(
tpl_name, granter=granter, user=user, role=role, datasource=datasource
)
@ -762,7 +763,13 @@ def send_email_smtp(
send_MIME_email(smtp_mail_from, recipients, msg, config, dryrun=dryrun)
def send_MIME_email(e_from, e_to, mime_msg, config, dryrun=False):
def send_MIME_email(
e_from: str,
e_to: List[str],
mime_msg: MIMEMultipart,
config: Dict[str, Any],
dryrun: bool = False,
) -> None:
SMTP_HOST = config["SMTP_HOST"]
SMTP_PORT = config["SMTP_PORT"]
SMTP_USER = config["SMTP_USER"]
@ -800,7 +807,7 @@ def choicify(values: Iterable[Any]) -> List[Tuple[Any, Any]]:
return [(v, v) for v in values]
def zlib_compress(data):
def zlib_compress(data: Union[bytes, str]) -> bytes:
"""
Compress things in a py2/3 safe fashion
>>> json_str = '{"test": 1}'
@ -827,7 +834,9 @@ def zlib_decompress(blob: bytes, decode: Optional[bool] = True) -> Union[bytes,
return decompressed.decode("utf-8") if decode else decompressed
def to_adhoc(filt, expressionType="SIMPLE", clause="where"):
def to_adhoc(
filt: Dict[str, Any], expressionType: str = "SIMPLE", clause: str = "where"
) -> Dict[str, Any]:
result = {
"clause": clause.upper(),
"expressionType": expressionType,
@ -849,7 +858,7 @@ def to_adhoc(filt, expressionType="SIMPLE", clause="where"):
return result
def merge_extra_filters(form_data: dict):
def merge_extra_filters(form_data: Dict[str, Any]) -> None:
# extra_filters are temporary/contextual filters (using the legacy constructs)
# that are external to the slice definition. We use those for dynamic
# interactive filters like the ones emitted by the "Filter Box" visualization.
@ -872,7 +881,7 @@ def merge_extra_filters(form_data: dict):
}
# Grab list of existing filters 'keyed' on the column and operator
def get_filter_key(f):
def get_filter_key(f: Dict[str, Any]) -> str:
if "expressionType" in f:
return "{}__{}".format(f["subject"], f["operator"])
else:
@ -945,7 +954,9 @@ def user_label(user: User) -> Optional[str]:
return None
def get_or_create_db(database_name, sqlalchemy_uri, *args, **kwargs):
def get_or_create_db(
database_name: str, sqlalchemy_uri: str, *args: Any, **kwargs: Any
) -> "Database":
from superset import db
from superset.models import core as models
@ -996,7 +1007,7 @@ def get_metric_names(metrics: Sequence[Metric]) -> List[str]:
return [get_metric_name(metric) for metric in metrics]
def ensure_path_exists(path: str):
def ensure_path_exists(path: str) -> None:
try:
os.makedirs(path)
except OSError as exc:
@ -1119,7 +1130,7 @@ def add_ago_to_since(since: str) -> str:
return since
def convert_legacy_filters_into_adhoc(fd):
def convert_legacy_filters_into_adhoc(fd: FormData) -> None:
mapping = {"having": "having_filters", "where": "filters"}
if not fd.get("adhoc_filters"):
@ -1138,7 +1149,7 @@ def convert_legacy_filters_into_adhoc(fd):
del fd[key]
def split_adhoc_filters_into_base_filters(fd):
def split_adhoc_filters_into_base_filters(fd: FormData) -> None:
"""
Mutates form data to restructure the adhoc filters in the form of the four base
filters, `where`, `having`, `filters`, and `having_filters` which represent
@ -1230,7 +1241,7 @@ def create_ssl_cert_file(certificate: str) -> str:
return path
def time_function(func: Callable, *args, **kwargs) -> Tuple[float, Any]:
def time_function(func: Callable, *args: Any, **kwargs: Any) -> Tuple[float, Any]:
"""
Measures the amount of time a function takes to execute in ms
@ -1296,7 +1307,7 @@ def split(
yield s[i:]
def get_iterable(x: Any) -> List:
def get_iterable(x: Any) -> List[Any]:
"""
Get an iterable (list) representation of the object.

View File

@ -17,14 +17,16 @@
import json
import logging
from collections import defaultdict
from typing import Dict, List
from typing import Any, Dict, List
from superset.models.slice import Slice
logger = logging.getLogger(__name__)
def convert_filter_scopes(json_metadata: Dict, filters: List[Slice]):
def convert_filter_scopes(
json_metadata: Dict[Any, Any], filters: List[Slice]
) -> Dict[int, Dict[str, Dict[str, Any]]]:
filter_scopes = {}
immuned_by_id: List[int] = json_metadata.get("filter_immune_slices") or []
immuned_by_column: Dict = defaultdict(list)
@ -34,7 +36,9 @@ def convert_filter_scopes(json_metadata: Dict, filters: List[Slice]):
for column in columns:
immuned_by_column[column].append(int(slice_id))
def add_filter_scope(filter_field, filter_id):
def add_filter_scope(
filter_fields: Dict[str, Dict[str, Any]], filter_field: str, filter_id: int
) -> None:
# in case filter field is invalid
if isinstance(filter_field, str):
current_filter_immune = list(
@ -54,17 +58,17 @@ def convert_filter_scopes(json_metadata: Dict, filters: List[Slice]):
configs = slice_params.get("filter_configs") or []
if slice_params.get("date_filter"):
add_filter_scope("__time_range", filter_id)
add_filter_scope(filter_fields, "__time_range", filter_id)
if slice_params.get("show_sqla_time_column"):
add_filter_scope("__time_col", filter_id)
add_filter_scope(filter_fields, "__time_col", filter_id)
if slice_params.get("show_sqla_time_granularity"):
add_filter_scope("__time_grain", filter_id)
add_filter_scope(filter_fields, "__time_grain", filter_id)
if slice_params.get("show_druid_time_granularity"):
add_filter_scope("__granularity", filter_id)
add_filter_scope(filter_fields, "__granularity", filter_id)
if slice_params.get("show_druid_time_origin"):
add_filter_scope("druid_time_origin", filter_id)
add_filter_scope(filter_fields, "druid_time_origin", filter_id)
for config in configs:
add_filter_scope(config.get("column"), filter_id)
add_filter_scope(filter_fields, config.get("column"), filter_id)
if filter_fields:
filter_scopes[filter_id] = filter_fields

View File

@ -19,6 +19,10 @@ import json
import logging
import time
from datetime import datetime
from io import BytesIO
from typing import Any, Dict, Optional
from sqlalchemy.orm import Session
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
from superset.models.dashboard import Dashboard
@ -27,7 +31,7 @@ from superset.models.slice import Slice
logger = logging.getLogger(__name__)
def decode_dashboards(o):
def decode_dashboards(o: Dict[str, Any]) -> Any:
"""
Function to be passed into json.loads obj_hook parameter
Recreates the dashboard object from a json representation.
@ -50,7 +54,9 @@ def decode_dashboards(o):
return o
def import_dashboards(session, data_stream, import_time=None):
def import_dashboards(
session: Session, data_stream: BytesIO, import_time: Optional[int] = None
) -> None:
"""Imports dashboards from a stream to databases"""
current_tt = int(time.time())
import_time = current_tt if import_time is None else import_time
@ -64,7 +70,7 @@ def import_dashboards(session, data_stream, import_time=None):
session.commit()
def export_dashboards(session):
def export_dashboards(session: Session) -> str:
"""Returns all dashboards metadata as a json dump"""
logger.info("Starting export")
dashboards = session.query(Dashboard)

View File

@ -21,7 +21,7 @@ import pytz
EPOCH = datetime(1970, 1, 1)
def datetime_to_epoch(dttm):
def datetime_to_epoch(dttm: datetime) -> float:
if dttm.tzinfo:
dttm = dttm.replace(tzinfo=pytz.utc)
epoch_with_tz = pytz.utc.localize(EPOCH)
@ -29,5 +29,5 @@ def datetime_to_epoch(dttm):
return (dttm - EPOCH).total_seconds() * 1000
def now_as_float():
def now_as_float() -> float:
return datetime_to_epoch(datetime.utcnow())

View File

@ -17,11 +17,14 @@
import logging
from datetime import datetime, timedelta
from functools import wraps
from typing import Any, Callable, Iterator
from contextlib2 import contextmanager
from flask import request
from werkzeug.wrappers.etag import ETagResponseMixin
from superset import app, cache
from superset.stats_logger import BaseStatsLogger
from superset.utils.dates import now_as_float
# If a user sets `max_age` to 0, for long the browser should cache the
@ -32,7 +35,7 @@ logger = logging.getLogger(__name__)
@contextmanager
def stats_timing(stats_key, stats_logger):
def stats_timing(stats_key: str, stats_logger: BaseStatsLogger) -> Iterator[float]:
"""Provide a transactional scope around a series of operations."""
start_ts = now_as_float()
try:
@ -43,7 +46,7 @@ def stats_timing(stats_key, stats_logger):
stats_logger.timing(stats_key, now_as_float() - start_ts)
def etag_cache(max_age, check_perms=bool):
def etag_cache(max_age: int, check_perms: Callable) -> Callable:
"""
A decorator for caching views and handling etag conditional requests.
@ -57,9 +60,9 @@ def etag_cache(max_age, check_perms=bool):
"""
def decorator(f):
def decorator(f: Callable) -> Callable:
@wraps(f)
def wrapper(*args, **kwargs):
def wrapper(*args: Any, **kwargs: Any) -> ETagResponseMixin:
# check if the user can access the resource
check_perms(*args, **kwargs)
@ -77,7 +80,9 @@ def etag_cache(max_age, check_perms=bool):
key_args = list(args)
key_kwargs = kwargs.copy()
key_kwargs.update(request.args)
cache_key = wrapper.make_cache_key(f, *key_args, **key_kwargs)
cache_key = wrapper.make_cache_key( # type: ignore
f, *key_args, **key_kwargs
)
response = cache.get(cache_key)
except Exception: # pylint: disable=broad-except
if app.debug:
@ -109,9 +114,9 @@ def etag_cache(max_age, check_perms=bool):
return response.make_conditional(request)
if cache:
wrapper.uncached = f
wrapper.cache_timeout = max_age
wrapper.make_cache_key = cache._memoize_make_cache_key( # pylint: disable=protected-access
wrapper.uncached = f # type: ignore
wrapper.cache_timeout = max_age # type: ignore
wrapper.make_cache_key = cache._memoize_make_cache_key( # type: ignore # pylint: disable=protected-access
make_name=None, timeout=max_age
)

View File

@ -16,6 +16,9 @@
# under the License.
# pylint: disable=C,R,W
import logging
from typing import Any, Dict, List, Optional
from sqlalchemy.orm import Session
from superset.connectors.druid.models import DruidCluster
from superset.models.core import Database
@ -25,7 +28,7 @@ DRUID_CLUSTERS_KEY = "druid_clusters"
logger = logging.getLogger(__name__)
def export_schema_to_dict(back_references):
def export_schema_to_dict(back_references: bool) -> Dict[str, Any]:
"""Exports the supported import/export schema to a dictionary"""
databases = [
Database.export_schema(recursive=True, include_parent_ref=back_references)
@ -41,7 +44,9 @@ def export_schema_to_dict(back_references):
return data
def export_to_dict(session, recursive, back_references, include_defaults):
def export_to_dict(
session: Session, recursive: bool, back_references: bool, include_defaults: bool
) -> Dict[str, Any]:
"""Exports databases and druid clusters to a dictionary"""
logger.info("Starting export")
dbs = session.query(Database)
@ -72,8 +77,12 @@ def export_to_dict(session, recursive, back_references, include_defaults):
return data
def import_from_dict(session, data, sync=[]):
def import_from_dict(
session: Session, data: Dict[str, Any], sync: Optional[List[str]] = None
) -> None:
"""Imports databases and druid clusters from dictionary"""
if not sync:
sync = []
if isinstance(data, dict):
logger.info("Importing %d %s", len(data.get(DATABASES_KEY, [])), DATABASES_KEY)
for database in data.get(DATABASES_KEY, []):

View File

@ -15,25 +15,33 @@
# specific language governing permissions and limitations
# under the License.
from copy import deepcopy
from typing import Any, Dict
from flask import Flask
class FeatureFlagManager:
def __init__(self) -> None:
super().__init__()
self._get_feature_flags_func = None
self._feature_flags = None
self._feature_flags: Dict[str, Any] = {}
def init_app(self, app):
self._get_feature_flags_func = app.config.get("GET_FEATURE_FLAGS_FUNC")
self._feature_flags = app.config.get("DEFAULT_FEATURE_FLAGS") or {}
self._feature_flags.update(app.config.get("FEATURE_FLAGS") or {})
def init_app(self, app: Flask) -> None:
self._get_feature_flags_func = app.config["GET_FEATURE_FLAGS_FUNC"]
self._feature_flags = app.config["DEFAULT_FEATURE_FLAGS"]
self._feature_flags.update(app.config["FEATURE_FLAGS"])
def get_feature_flags(self):
def get_feature_flags(self) -> Dict[str, Any]:
if self._get_feature_flags_func:
return self._get_feature_flags_func(deepcopy(self._feature_flags))
return self._feature_flags
def is_feature_enabled(self, feature) -> bool:
def is_feature_enabled(self, feature: str) -> bool:
"""Utility function for checking whether a feature is turned on"""
return self.get_feature_flags().get(feature)
feature_flags = self.get_feature_flags()
if feature_flags and feature in feature_flags:
return feature_flags[feature]
return False

View File

@ -21,19 +21,23 @@ import logging
import textwrap
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Any, cast, Type
from typing import Any, Callable, cast, Optional, Type
from flask import current_app, g, request
from superset.stats_logger import BaseStatsLogger
class AbstractEventLogger(ABC):
@abstractmethod
def log(self, user_id, action, *args, **kwargs):
def log(
self, user_id: Optional[int], action: str, *args: Any, **kwargs: Any
) -> None:
pass
def log_this(self, f):
def log_this(self, f: Callable) -> Callable:
@functools.wraps(f)
def wrapper(*args, **kwargs):
def wrapper(*args: Any, **kwargs: Any) -> Any:
user_id = None
if g.user:
user_id = g.user.get_id()
@ -49,7 +53,12 @@ class AbstractEventLogger(ABC):
try:
slice_id = int(
slice_id or json.loads(form_data.get("form_data")).get("slice_id")
slice_id
or json.loads(
form_data.get("form_data") # type: ignore
).get(
"slice_id"
)
)
except (ValueError, TypeError):
slice_id = 0
@ -62,7 +71,7 @@ class AbstractEventLogger(ABC):
# bulk insert
try:
explode_by = form_data.get("explode")
records = json.loads(form_data.get(explode_by))
records = json.loads(form_data.get(explode_by)) # type: ignore
except Exception: # pylint: disable=broad-except
records = [form_data]
@ -82,11 +91,11 @@ class AbstractEventLogger(ABC):
return wrapper
@property
def stats_logger(self):
def stats_logger(self) -> BaseStatsLogger:
return current_app.config["STATS_LOGGER"]
def get_event_logger_from_cfg_value(cfg_value: object) -> AbstractEventLogger:
def get_event_logger_from_cfg_value(cfg_value: Any) -> AbstractEventLogger:
"""
This function implements the deprecation of assignment
of class objects to EVENT_LOGGER configuration, and validates
@ -130,7 +139,9 @@ def get_event_logger_from_cfg_value(cfg_value: object) -> AbstractEventLogger:
class DBEventLogger(AbstractEventLogger):
def log(self, user_id, action, *args, **kwargs): # pylint: disable=too-many-locals
def log( # pylint: disable=too-many-locals
self, user_id: Optional[int], action: str, *args: Any, **kwargs: Any
) -> None:
from superset.models.core import Log
records = kwargs.get("records", list())
@ -141,6 +152,7 @@ class DBEventLogger(AbstractEventLogger):
logs = list()
for record in records:
json_string: Optional[str]
try:
json_string = json.dumps(record)
except Exception: # pylint: disable=broad-except

View File

@ -73,8 +73,8 @@ WHITELIST_CUMULATIVE_FUNCTIONS = (
def validate_column_args(*argnames: str) -> Callable:
def wrapper(func):
def wrapped(df, **options):
def wrapper(func: Callable) -> Callable:
def wrapped(df: DataFrame, **options: Any) -> Any:
columns = df.columns.tolist()
for name in argnames:
if name in options and not all(
@ -159,7 +159,7 @@ def pivot( # pylint: disable=too-many-arguments
metric_fill_value: Optional[Any] = None,
column_fill_value: Optional[str] = None,
drop_missing_columns: Optional[bool] = True,
combine_value_with_metric=False,
combine_value_with_metric: bool = False,
marginal_distributions: Optional[bool] = None,
marginal_distribution_name: Optional[str] = None,
) -> DataFrame:

View File

@ -18,7 +18,7 @@ import logging
import time
import urllib.parse
from io import BytesIO
from typing import Callable, Dict, List, Optional, Tuple, TYPE_CHECKING
from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING
from flask import current_app, request, Response, session, url_for
from flask_login import login_user
@ -91,7 +91,7 @@ def headless_url(path: str) -> str:
return urllib.parse.urljoin(current_app.config.get("WEBDRIVER_BASEURL", ""), path)
def get_url_path(view: str, **kwargs) -> str:
def get_url_path(view: str, **kwargs: Any) -> str:
with current_app.test_request_context():
return headless_url(url_for(view, **kwargs))
@ -135,7 +135,7 @@ class AuthWebDriverProxy:
return self._auth_func(driver, user)
@staticmethod
def destroy(driver: WebDriver, tries=2):
def destroy(driver: WebDriver, tries: int = 2) -> None:
"""Destroy a driver"""
# This is some very flaky code in selenium. Hence the retries
# and catch-all exceptions

View File

@ -14,22 +14,24 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from werkzeug.routing import BaseConverter
from typing import Any, List
from werkzeug.routing import BaseConverter, Map
from superset.models.tags import ObjectTypes
class RegexConverter(BaseConverter):
def __init__(self, url_map, *items):
super(RegexConverter, self).__init__(url_map)
def __init__(self, url_map: Map, *items: List[str]) -> None:
super(RegexConverter, self).__init__(url_map) # type: ignore
self.regex = items[0]
class ObjectTypeConverter(BaseConverter):
"""Validate that object_type is indeed an object type."""
def to_python(self, value):
def to_python(self, value: str) -> Any:
return ObjectTypes[value]
def to_url(self, value):
def to_url(self, value: Any) -> str:
return value.name

View File

@ -2164,7 +2164,7 @@ class Superset(BaseSupersetView):
return json_error_response(str(ex))
spec = mydb.db_engine_spec
query_cost_formatters = get_feature_flags().get(
query_cost_formatters: Dict[str, Any] = get_feature_flags().get(
"QUERY_COST_FORMATTERS_BY_ENGINE", {}
)
query_cost_formatter = query_cost_formatters.get(

View File

@ -38,7 +38,6 @@ from superset.utils.core import (
base_json_conv,
convert_legacy_filters_into_adhoc,
create_ssl_cert_file,
datetime_f,
format_timedelta,
get_iterable,
get_email_address_list,
@ -560,17 +559,6 @@ class UtilsTestCase(SupersetTestCase):
url_params["dashboard_ids"], form_data["url_params"]["dashboard_ids"]
)
def test_datetime_f(self):
self.assertEqual(
datetime_f(datetime(1990, 9, 21, 19, 11, 19, 626096)),
"<nobr>1990-09-21T19:11:19.626096</nobr>",
)
self.assertEqual(len(datetime_f(datetime.now())), 28)
self.assertEqual(datetime_f(None), "<nobr>None</nobr>")
iso = datetime.now().isoformat()[:10].split("-")
[a, b, c] = [int(v) for v in iso]
self.assertEqual(datetime_f(datetime(a, b, c)), "<nobr>00:00:00</nobr>")
def test_format_timedelta(self):
self.assertEqual(format_timedelta(timedelta(0)), "0:00:00")
self.assertEqual(format_timedelta(timedelta(days=1)), "1 day, 0:00:00")