[mypy] Enforcing typing for superset.utils (#9905)
Co-authored-by: John Bodley <john.bodley@airbnb.com>
This commit is contained in:
parent
54dced1cf6
commit
b296a0f250
|
|
@ -53,7 +53,7 @@ order_by_type = false
|
||||||
ignore_missing_imports = true
|
ignore_missing_imports = true
|
||||||
no_implicit_optional = true
|
no_implicit_optional = true
|
||||||
|
|
||||||
[mypy-superset.bin.*,superset.charts.*,superset.commands.*,superset.common.*,superset.connectors.*,superset.dao.*,superset.dashboards.*,superset.datasets.*,superset.db_engine_specs.*,superset.db_engines.*,superset.examples.*,superset.migrations.*,superset.models.*,uperset.queries.*,superset.security.*,superset.sql_validators.*,superset.tasks.*,superset.translations.*,superset.views.chart.*,superset.views.dashboard.*,superset.views.database.*]
|
[mypy-superset.bin.*,superset.charts.*,superset.commands.*,superset.common.*,superset.connectors.*,superset.dao.*,superset.dashboards.*,superset.datasets.*,superset.db_engine_specs.*,superset.db_engines.*,superset.examples.*,superset.migrations.*,superset.models.*,uperset.queries.*,superset.security.*,superset.sql_validators.*,superset.tasks.*,superset.translations.*,superset.utils.*,superset.views.chart.*,superset.views.dashboard.*,superset.views.database.*]
|
||||||
check_untyped_defs = true
|
check_untyped_defs = true
|
||||||
disallow_untyped_calls = true
|
disallow_untyped_calls = true
|
||||||
disallow_untyped_defs = true
|
disallow_untyped_defs = true
|
||||||
|
|
|
||||||
|
|
@ -279,7 +279,7 @@ LANGUAGES = {
|
||||||
# For example, DEFAULT_FEATURE_FLAGS = { 'FOO': True, 'BAR': False } here
|
# For example, DEFAULT_FEATURE_FLAGS = { 'FOO': True, 'BAR': False } here
|
||||||
# and FEATURE_FLAGS = { 'BAR': True, 'BAZ': True } in superset_config.py
|
# and FEATURE_FLAGS = { 'BAR': True, 'BAZ': True } in superset_config.py
|
||||||
# will result in combined feature flags of { 'FOO': True, 'BAR': True, 'BAZ': True }
|
# will result in combined feature flags of { 'FOO': True, 'BAR': True, 'BAZ': True }
|
||||||
DEFAULT_FEATURE_FLAGS = {
|
DEFAULT_FEATURE_FLAGS: Dict[str, bool] = {
|
||||||
# Experimental feature introducing a client (browser) cache
|
# Experimental feature introducing a client (browser) cache
|
||||||
"CLIENT_CACHE": False,
|
"CLIENT_CACHE": False,
|
||||||
"ENABLE_EXPLORE_JSON_CSRF_PROTECTION": False,
|
"ENABLE_EXPLORE_JSON_CSRF_PROTECTION": False,
|
||||||
|
|
|
||||||
|
|
@ -28,6 +28,7 @@ DbapiDescription = Union[List[DbapiDescriptionRow], Tuple[DbapiDescriptionRow, .
|
||||||
DbapiResult = List[Union[List[Any], Tuple[Any, ...]]]
|
DbapiResult = List[Union[List[Any], Tuple[Any, ...]]]
|
||||||
FilterValue = Union[float, int, str]
|
FilterValue = Union[float, int, str]
|
||||||
FilterValues = Union[FilterValue, List[FilterValue], Tuple[FilterValue]]
|
FilterValues = Union[FilterValue, List[FilterValue], Tuple[FilterValue]]
|
||||||
|
FormData = Dict[str, Any]
|
||||||
Granularity = Union[str, Dict[str, Union[str, float]]]
|
Granularity = Union[str, Dict[str, Union[str, float]]]
|
||||||
Metric = Union[Dict[str, str], str]
|
Metric = Union[Dict[str, str], str]
|
||||||
QueryObjectDict = Dict[str, Any]
|
QueryObjectDict = Dict[str, Any]
|
||||||
|
|
|
||||||
|
|
@ -14,14 +14,14 @@
|
||||||
# KIND, either express or implied. See the License for the
|
# KIND, either express or implied. See the License for the
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
from typing import Callable, Optional
|
from typing import Any, Callable, Optional
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
|
|
||||||
from superset.extensions import cache_manager
|
from superset.extensions import cache_manager
|
||||||
|
|
||||||
|
|
||||||
def view_cache_key(*_, **__) -> str:
|
def view_cache_key(*args: Any, **kwargs: Any) -> str: # pylint: disable=unused-argument
|
||||||
args_hash = hash(frozenset(request.args.items()))
|
args_hash = hash(frozenset(request.args.items()))
|
||||||
return "view/{}/{}".format(request.path, args_hash)
|
return "view/{}/{}".format(request.path, args_hash)
|
||||||
|
|
||||||
|
|
@ -45,10 +45,10 @@ def memoized_func(
|
||||||
returns the caching key.
|
returns the caching key.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def wrap(f):
|
def wrap(f: Callable) -> Callable:
|
||||||
if cache_manager.tables_cache:
|
if cache_manager.tables_cache:
|
||||||
|
|
||||||
def wrapped_f(self, *args, **kwargs):
|
def wrapped_f(self: Any, *args: Any, **kwargs: Any) -> Any:
|
||||||
if not kwargs.get("cache", True):
|
if not kwargs.get("cache", True):
|
||||||
return f(self, *args, **kwargs)
|
return f(self, *args, **kwargs)
|
||||||
|
|
||||||
|
|
@ -69,7 +69,7 @@ def memoized_func(
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# noop
|
# noop
|
||||||
def wrapped_f(self, *args, **kwargs):
|
def wrapped_f(self: Any, *args: Any, **kwargs: Any) -> Any:
|
||||||
return f(self, *args, **kwargs)
|
return f(self, *args, **kwargs)
|
||||||
|
|
||||||
return wrapped_f
|
return wrapped_f
|
||||||
|
|
|
||||||
|
|
@ -39,6 +39,7 @@ from email.utils import formatdate
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from time import struct_time
|
from time import struct_time
|
||||||
from timeit import default_timer
|
from timeit import default_timer
|
||||||
|
from types import TracebackType
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
Callable,
|
Callable,
|
||||||
|
|
@ -51,6 +52,7 @@ from typing import (
|
||||||
Sequence,
|
Sequence,
|
||||||
Set,
|
Set,
|
||||||
Tuple,
|
Tuple,
|
||||||
|
Type,
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
|
|
@ -69,10 +71,12 @@ from dateutil.parser import parse
|
||||||
from dateutil.relativedelta import relativedelta
|
from dateutil.relativedelta import relativedelta
|
||||||
from flask import current_app, flash, g, Markup, render_template
|
from flask import current_app, flash, g, Markup, render_template
|
||||||
from flask_appbuilder import SQLA
|
from flask_appbuilder import SQLA
|
||||||
from flask_appbuilder.security.sqla.models import User
|
from flask_appbuilder.security.sqla.models import Role, User
|
||||||
from flask_babel import gettext as __, lazy_gettext as _
|
from flask_babel import gettext as __, lazy_gettext as _
|
||||||
from sqlalchemy import event, exc, select, Text
|
from sqlalchemy import event, exc, select, Text
|
||||||
from sqlalchemy.dialects.mysql import MEDIUMTEXT
|
from sqlalchemy.dialects.mysql import MEDIUMTEXT
|
||||||
|
from sqlalchemy.engine import Connection, Engine
|
||||||
|
from sqlalchemy.engine.reflection import Inspector
|
||||||
from sqlalchemy.sql.type_api import Variant
|
from sqlalchemy.sql.type_api import Variant
|
||||||
from sqlalchemy.types import TEXT, TypeDecorator
|
from sqlalchemy.types import TEXT, TypeDecorator
|
||||||
|
|
||||||
|
|
@ -81,7 +85,7 @@ from superset.exceptions import (
|
||||||
SupersetException,
|
SupersetException,
|
||||||
SupersetTimeoutException,
|
SupersetTimeoutException,
|
||||||
)
|
)
|
||||||
from superset.typing import Metric
|
from superset.typing import FormData, Metric
|
||||||
from superset.utils.dates import datetime_to_epoch, EPOCH
|
from superset.utils.dates import datetime_to_epoch, EPOCH
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -90,6 +94,7 @@ except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from superset.connectors.base.models import BaseDatasource
|
||||||
from superset.models.core import Database
|
from superset.models.core import Database
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -121,7 +126,7 @@ except NameError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def flasher(msg: str, severity: str) -> None:
|
def flasher(msg: str, severity: str = "message") -> None:
|
||||||
"""Flask's flash if available, logging call if not"""
|
"""Flask's flash if available, logging call if not"""
|
||||||
try:
|
try:
|
||||||
flash(msg, severity)
|
flash(msg, severity)
|
||||||
|
|
@ -142,17 +147,17 @@ class _memoized:
|
||||||
should account for instance variable changes.
|
should account for instance variable changes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, func, watch=()):
|
def __init__(self, func: Callable, watch: Optional[List[str]] = None) -> None:
|
||||||
self.func = func
|
self.func = func
|
||||||
self.cache = {}
|
self.cache: Dict[Any, Any] = {}
|
||||||
self.is_method = False
|
self.is_method = False
|
||||||
self.watch = watch or []
|
self.watch = watch or []
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
||||||
key = [args, frozenset(kwargs.items())]
|
key = [args, frozenset(kwargs.items())]
|
||||||
if self.is_method:
|
if self.is_method:
|
||||||
key.append(tuple([getattr(args[0], v, None) for v in self.watch]))
|
key.append(tuple([getattr(args[0], v, None) for v in self.watch]))
|
||||||
key = tuple(key)
|
key = tuple(key) # type: ignore
|
||||||
if key in self.cache:
|
if key in self.cache:
|
||||||
return self.cache[key]
|
return self.cache[key]
|
||||||
try:
|
try:
|
||||||
|
|
@ -164,23 +169,25 @@ class _memoized:
|
||||||
# Better to not cache than to blow up entirely.
|
# Better to not cache than to blow up entirely.
|
||||||
return self.func(*args, **kwargs)
|
return self.func(*args, **kwargs)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self) -> str:
|
||||||
"""Return the function's docstring."""
|
"""Return the function's docstring."""
|
||||||
return self.func.__doc__
|
return self.func.__doc__ or ""
|
||||||
|
|
||||||
def __get__(self, obj, objtype):
|
def __get__(self, obj: Any, objtype: Type) -> functools.partial:
|
||||||
if not self.is_method:
|
if not self.is_method:
|
||||||
self.is_method = True
|
self.is_method = True
|
||||||
"""Support instance methods."""
|
"""Support instance methods."""
|
||||||
return functools.partial(self.__call__, obj)
|
return functools.partial(self.__call__, obj)
|
||||||
|
|
||||||
|
|
||||||
def memoized(func: Optional[Callable] = None, watch: Optional[List[str]] = None):
|
def memoized(
|
||||||
|
func: Optional[Callable] = None, watch: Optional[List[str]] = None
|
||||||
|
) -> Callable:
|
||||||
if func:
|
if func:
|
||||||
return _memoized(func)
|
return _memoized(func)
|
||||||
else:
|
else:
|
||||||
|
|
||||||
def wrapper(f):
|
def wrapper(f: Callable) -> Callable:
|
||||||
return _memoized(f, watch)
|
return _memoized(f, watch)
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
@ -229,7 +236,7 @@ def cast_to_num(value: Union[float, int, str]) -> Optional[Union[float, int]]:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def list_minus(l: List, minus: List) -> List:
|
def list_minus(l: List[Any], minus: List[Any]) -> List[Any]:
|
||||||
"""Returns l without what is in minus
|
"""Returns l without what is in minus
|
||||||
|
|
||||||
>>> list_minus([1, 2, 3], [2])
|
>>> list_minus([1, 2, 3], [2])
|
||||||
|
|
@ -284,19 +291,19 @@ def md5_hex(data: str) -> str:
|
||||||
|
|
||||||
|
|
||||||
class DashboardEncoder(json.JSONEncoder):
|
class DashboardEncoder(json.JSONEncoder):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.sort_keys = True
|
self.sort_keys = True
|
||||||
|
|
||||||
# pylint: disable=E0202
|
# pylint: disable=E0202
|
||||||
def default(self, o):
|
def default(self, o: Any) -> Dict[Any, Any]:
|
||||||
try:
|
try:
|
||||||
vals = {k: v for k, v in o.__dict__.items() if k != "_sa_instance_state"}
|
vals = {k: v for k, v in o.__dict__.items() if k != "_sa_instance_state"}
|
||||||
return {"__{}__".format(o.__class__.__name__): vals}
|
return {"__{}__".format(o.__class__.__name__): vals}
|
||||||
except Exception:
|
except Exception:
|
||||||
if type(o) == datetime:
|
if type(o) == datetime:
|
||||||
return {"__datetime__": o.replace(microsecond=0).isoformat()}
|
return {"__datetime__": o.replace(microsecond=0).isoformat()}
|
||||||
return json.JSONEncoder(sort_keys=True).default(self, o)
|
return json.JSONEncoder(sort_keys=True).default(o)
|
||||||
|
|
||||||
|
|
||||||
def parse_human_timedelta(s: Optional[str]) -> timedelta:
|
def parse_human_timedelta(s: Optional[str]) -> timedelta:
|
||||||
|
|
@ -332,28 +339,15 @@ class JSONEncodedDict(TypeDecorator):
|
||||||
|
|
||||||
impl = TEXT
|
impl = TEXT
|
||||||
|
|
||||||
def process_bind_param(self, value, dialect):
|
def process_bind_param(
|
||||||
if value is not None:
|
self, value: Optional[Dict[Any, Any]], dialect: str
|
||||||
value = json.dumps(value)
|
) -> Optional[str]:
|
||||||
|
return json.dumps(value) if value is not None else None
|
||||||
|
|
||||||
return value
|
def process_result_value(
|
||||||
|
self, value: Optional[str], dialect: str
|
||||||
def process_result_value(self, value, dialect):
|
) -> Optional[Dict[Any, Any]]:
|
||||||
if value is not None:
|
return json.loads(value) if value is not None else None
|
||||||
value = json.loads(value)
|
|
||||||
return value
|
|
||||||
|
|
||||||
|
|
||||||
def datetime_f(dttm):
|
|
||||||
"""Formats datetime to take less room when it is recent"""
|
|
||||||
if dttm:
|
|
||||||
dttm = dttm.isoformat()
|
|
||||||
now_iso = datetime.now().isoformat()
|
|
||||||
if now_iso[:10] == dttm[:10]:
|
|
||||||
dttm = dttm[11:]
|
|
||||||
elif now_iso[:4] == dttm[:4]:
|
|
||||||
dttm = dttm[5:]
|
|
||||||
return "<nobr>{}</nobr>".format(dttm)
|
|
||||||
|
|
||||||
|
|
||||||
def format_timedelta(td: timedelta) -> str:
|
def format_timedelta(td: timedelta) -> str:
|
||||||
|
|
@ -373,7 +367,7 @@ def format_timedelta(td: timedelta) -> str:
|
||||||
return str(td)
|
return str(td)
|
||||||
|
|
||||||
|
|
||||||
def base_json_conv(obj):
|
def base_json_conv(obj: Any) -> Any:
|
||||||
if isinstance(obj, memoryview):
|
if isinstance(obj, memoryview):
|
||||||
obj = obj.tobytes()
|
obj = obj.tobytes()
|
||||||
if isinstance(obj, np.int64):
|
if isinstance(obj, np.int64):
|
||||||
|
|
@ -397,7 +391,7 @@ def base_json_conv(obj):
|
||||||
return "[bytes]"
|
return "[bytes]"
|
||||||
|
|
||||||
|
|
||||||
def json_iso_dttm_ser(obj, pessimistic: Optional[bool] = False):
|
def json_iso_dttm_ser(obj: Any, pessimistic: bool = False) -> str:
|
||||||
"""
|
"""
|
||||||
json serializer that deals with dates
|
json serializer that deals with dates
|
||||||
|
|
||||||
|
|
@ -420,14 +414,14 @@ def json_iso_dttm_ser(obj, pessimistic: Optional[bool] = False):
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
|
||||||
def pessimistic_json_iso_dttm_ser(obj):
|
def pessimistic_json_iso_dttm_ser(obj: Any) -> str:
|
||||||
"""Proxy to call json_iso_dttm_ser in a pessimistic way
|
"""Proxy to call json_iso_dttm_ser in a pessimistic way
|
||||||
|
|
||||||
If one of object is not serializable to json, it will still succeed"""
|
If one of object is not serializable to json, it will still succeed"""
|
||||||
return json_iso_dttm_ser(obj, pessimistic=True)
|
return json_iso_dttm_ser(obj, pessimistic=True)
|
||||||
|
|
||||||
|
|
||||||
def json_int_dttm_ser(obj):
|
def json_int_dttm_ser(obj: Any) -> float:
|
||||||
"""json serializer that deals with dates"""
|
"""json serializer that deals with dates"""
|
||||||
val = base_json_conv(obj)
|
val = base_json_conv(obj)
|
||||||
if val is not None:
|
if val is not None:
|
||||||
|
|
@ -441,7 +435,7 @@ def json_int_dttm_ser(obj):
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
|
||||||
def json_dumps_w_dates(payload):
|
def json_dumps_w_dates(payload: Dict[Any, Any]) -> str:
|
||||||
return json.dumps(payload, default=json_int_dttm_ser)
|
return json.dumps(payload, default=json_int_dttm_ser)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -522,7 +516,7 @@ def readfile(file_path: str) -> Optional[str]:
|
||||||
|
|
||||||
def generic_find_constraint_name(
|
def generic_find_constraint_name(
|
||||||
table: str, columns: Set[str], referenced: str, db: SQLA
|
table: str, columns: Set[str], referenced: str, db: SQLA
|
||||||
):
|
) -> Optional[str]:
|
||||||
"""Utility to find a constraint name in alembic migrations"""
|
"""Utility to find a constraint name in alembic migrations"""
|
||||||
t = sa.Table(table, db.metadata, autoload=True, autoload_with=db.engine)
|
t = sa.Table(table, db.metadata, autoload=True, autoload_with=db.engine)
|
||||||
|
|
||||||
|
|
@ -530,10 +524,12 @@ def generic_find_constraint_name(
|
||||||
if fk.referred_table.name == referenced and set(fk.column_keys) == columns:
|
if fk.referred_table.name == referenced and set(fk.column_keys) == columns:
|
||||||
return fk.name
|
return fk.name
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def generic_find_fk_constraint_name(
|
def generic_find_fk_constraint_name(
|
||||||
table: str, columns: Set[str], referenced: str, insp
|
table: str, columns: Set[str], referenced: str, insp: Inspector
|
||||||
):
|
) -> Optional[str]:
|
||||||
"""Utility to find a foreign-key constraint name in alembic migrations"""
|
"""Utility to find a foreign-key constraint name in alembic migrations"""
|
||||||
for fk in insp.get_foreign_keys(table):
|
for fk in insp.get_foreign_keys(table):
|
||||||
if (
|
if (
|
||||||
|
|
@ -542,8 +538,12 @@ def generic_find_fk_constraint_name(
|
||||||
):
|
):
|
||||||
return fk["name"]
|
return fk["name"]
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
def generic_find_fk_constraint_names(table, columns, referenced, insp):
|
|
||||||
|
def generic_find_fk_constraint_names(
|
||||||
|
table: str, columns: Set[str], referenced: str, insp: Inspector
|
||||||
|
) -> Set[str]:
|
||||||
"""Utility to find foreign-key constraint names in alembic migrations"""
|
"""Utility to find foreign-key constraint names in alembic migrations"""
|
||||||
names = set()
|
names = set()
|
||||||
|
|
||||||
|
|
@ -557,13 +557,17 @@ def generic_find_fk_constraint_names(table, columns, referenced, insp):
|
||||||
return names
|
return names
|
||||||
|
|
||||||
|
|
||||||
def generic_find_uq_constraint_name(table, columns, insp):
|
def generic_find_uq_constraint_name(
|
||||||
|
table: str, columns: Set[str], insp: Inspector
|
||||||
|
) -> Optional[str]:
|
||||||
"""Utility to find a unique constraint name in alembic migrations"""
|
"""Utility to find a unique constraint name in alembic migrations"""
|
||||||
|
|
||||||
for uq in insp.get_unique_constraints(table):
|
for uq in insp.get_unique_constraints(table):
|
||||||
if columns == set(uq["column_names"]):
|
if columns == set(uq["column_names"]):
|
||||||
return uq["name"]
|
return uq["name"]
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def get_datasource_full_name(
|
def get_datasource_full_name(
|
||||||
database_name: str, datasource_name: str, schema: Optional[str] = None
|
database_name: str, datasource_name: str, schema: Optional[str] = None
|
||||||
|
|
@ -582,30 +586,20 @@ def validate_json(obj: Union[bytes, bytearray, str]) -> None:
|
||||||
raise SupersetException("JSON is not valid")
|
raise SupersetException("JSON is not valid")
|
||||||
|
|
||||||
|
|
||||||
def table_has_constraint(table, name, db):
|
|
||||||
"""Utility to find a constraint name in alembic migrations"""
|
|
||||||
t = sa.Table(table, db.metadata, autoload=True, autoload_with=db.engine)
|
|
||||||
|
|
||||||
for c in t.constraints:
|
|
||||||
if c.name == name:
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
class timeout:
|
class timeout:
|
||||||
"""
|
"""
|
||||||
To be used in a ``with`` block and timeout its content.
|
To be used in a ``with`` block and timeout its content.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, seconds=1, error_message="Timeout"):
|
def __init__(self, seconds: int = 1, error_message: str = "Timeout") -> None:
|
||||||
self.seconds = seconds
|
self.seconds = seconds
|
||||||
self.error_message = error_message
|
self.error_message = error_message
|
||||||
|
|
||||||
def handle_timeout(self, signum, frame):
|
def handle_timeout(self, signum: int, frame: Any) -> None:
|
||||||
logger.error("Process timed out")
|
logger.error("Process timed out")
|
||||||
raise SupersetTimeoutException(self.error_message)
|
raise SupersetTimeoutException(self.error_message)
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self) -> None:
|
||||||
try:
|
try:
|
||||||
signal.signal(signal.SIGALRM, self.handle_timeout)
|
signal.signal(signal.SIGALRM, self.handle_timeout)
|
||||||
signal.alarm(self.seconds)
|
signal.alarm(self.seconds)
|
||||||
|
|
@ -613,7 +607,7 @@ class timeout:
|
||||||
logger.warning("timeout can't be used in the current context")
|
logger.warning("timeout can't be used in the current context")
|
||||||
logger.exception(ex)
|
logger.exception(ex)
|
||||||
|
|
||||||
def __exit__(self, type, value, traceback):
|
def __exit__(self, type: Any, value: Any, traceback: TracebackType) -> None:
|
||||||
try:
|
try:
|
||||||
signal.alarm(0)
|
signal.alarm(0)
|
||||||
except ValueError as ex:
|
except ValueError as ex:
|
||||||
|
|
@ -621,9 +615,9 @@ class timeout:
|
||||||
logger.exception(ex)
|
logger.exception(ex)
|
||||||
|
|
||||||
|
|
||||||
def pessimistic_connection_handling(some_engine):
|
def pessimistic_connection_handling(some_engine: Engine) -> None:
|
||||||
@event.listens_for(some_engine, "engine_connect")
|
@event.listens_for(some_engine, "engine_connect")
|
||||||
def ping_connection(connection, branch):
|
def ping_connection(connection: Connection, branch: bool) -> None:
|
||||||
if branch:
|
if branch:
|
||||||
# 'branch' refers to a sub-connection of a connection,
|
# 'branch' refers to a sub-connection of a connection,
|
||||||
# we don't want to bother pinging on these.
|
# we don't want to bother pinging on these.
|
||||||
|
|
@ -670,7 +664,14 @@ class QueryStatus:
|
||||||
TIMED_OUT: str = "timed_out"
|
TIMED_OUT: str = "timed_out"
|
||||||
|
|
||||||
|
|
||||||
def notify_user_about_perm_udate(granter, user, role, datasource, tpl_name, config):
|
def notify_user_about_perm_udate(
|
||||||
|
granter: User,
|
||||||
|
user: User,
|
||||||
|
role: Role,
|
||||||
|
datasource: "BaseDatasource",
|
||||||
|
tpl_name: str,
|
||||||
|
config: Dict[str, Any],
|
||||||
|
) -> None:
|
||||||
msg = render_template(
|
msg = render_template(
|
||||||
tpl_name, granter=granter, user=user, role=role, datasource=datasource
|
tpl_name, granter=granter, user=user, role=role, datasource=datasource
|
||||||
)
|
)
|
||||||
|
|
@ -762,7 +763,13 @@ def send_email_smtp(
|
||||||
send_MIME_email(smtp_mail_from, recipients, msg, config, dryrun=dryrun)
|
send_MIME_email(smtp_mail_from, recipients, msg, config, dryrun=dryrun)
|
||||||
|
|
||||||
|
|
||||||
def send_MIME_email(e_from, e_to, mime_msg, config, dryrun=False):
|
def send_MIME_email(
|
||||||
|
e_from: str,
|
||||||
|
e_to: List[str],
|
||||||
|
mime_msg: MIMEMultipart,
|
||||||
|
config: Dict[str, Any],
|
||||||
|
dryrun: bool = False,
|
||||||
|
) -> None:
|
||||||
SMTP_HOST = config["SMTP_HOST"]
|
SMTP_HOST = config["SMTP_HOST"]
|
||||||
SMTP_PORT = config["SMTP_PORT"]
|
SMTP_PORT = config["SMTP_PORT"]
|
||||||
SMTP_USER = config["SMTP_USER"]
|
SMTP_USER = config["SMTP_USER"]
|
||||||
|
|
@ -800,7 +807,7 @@ def choicify(values: Iterable[Any]) -> List[Tuple[Any, Any]]:
|
||||||
return [(v, v) for v in values]
|
return [(v, v) for v in values]
|
||||||
|
|
||||||
|
|
||||||
def zlib_compress(data):
|
def zlib_compress(data: Union[bytes, str]) -> bytes:
|
||||||
"""
|
"""
|
||||||
Compress things in a py2/3 safe fashion
|
Compress things in a py2/3 safe fashion
|
||||||
>>> json_str = '{"test": 1}'
|
>>> json_str = '{"test": 1}'
|
||||||
|
|
@ -827,7 +834,9 @@ def zlib_decompress(blob: bytes, decode: Optional[bool] = True) -> Union[bytes,
|
||||||
return decompressed.decode("utf-8") if decode else decompressed
|
return decompressed.decode("utf-8") if decode else decompressed
|
||||||
|
|
||||||
|
|
||||||
def to_adhoc(filt, expressionType="SIMPLE", clause="where"):
|
def to_adhoc(
|
||||||
|
filt: Dict[str, Any], expressionType: str = "SIMPLE", clause: str = "where"
|
||||||
|
) -> Dict[str, Any]:
|
||||||
result = {
|
result = {
|
||||||
"clause": clause.upper(),
|
"clause": clause.upper(),
|
||||||
"expressionType": expressionType,
|
"expressionType": expressionType,
|
||||||
|
|
@ -849,7 +858,7 @@ def to_adhoc(filt, expressionType="SIMPLE", clause="where"):
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def merge_extra_filters(form_data: dict):
|
def merge_extra_filters(form_data: Dict[str, Any]) -> None:
|
||||||
# extra_filters are temporary/contextual filters (using the legacy constructs)
|
# extra_filters are temporary/contextual filters (using the legacy constructs)
|
||||||
# that are external to the slice definition. We use those for dynamic
|
# that are external to the slice definition. We use those for dynamic
|
||||||
# interactive filters like the ones emitted by the "Filter Box" visualization.
|
# interactive filters like the ones emitted by the "Filter Box" visualization.
|
||||||
|
|
@ -872,7 +881,7 @@ def merge_extra_filters(form_data: dict):
|
||||||
}
|
}
|
||||||
# Grab list of existing filters 'keyed' on the column and operator
|
# Grab list of existing filters 'keyed' on the column and operator
|
||||||
|
|
||||||
def get_filter_key(f):
|
def get_filter_key(f: Dict[str, Any]) -> str:
|
||||||
if "expressionType" in f:
|
if "expressionType" in f:
|
||||||
return "{}__{}".format(f["subject"], f["operator"])
|
return "{}__{}".format(f["subject"], f["operator"])
|
||||||
else:
|
else:
|
||||||
|
|
@ -945,7 +954,9 @@ def user_label(user: User) -> Optional[str]:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def get_or_create_db(database_name, sqlalchemy_uri, *args, **kwargs):
|
def get_or_create_db(
|
||||||
|
database_name: str, sqlalchemy_uri: str, *args: Any, **kwargs: Any
|
||||||
|
) -> "Database":
|
||||||
from superset import db
|
from superset import db
|
||||||
from superset.models import core as models
|
from superset.models import core as models
|
||||||
|
|
||||||
|
|
@ -996,7 +1007,7 @@ def get_metric_names(metrics: Sequence[Metric]) -> List[str]:
|
||||||
return [get_metric_name(metric) for metric in metrics]
|
return [get_metric_name(metric) for metric in metrics]
|
||||||
|
|
||||||
|
|
||||||
def ensure_path_exists(path: str):
|
def ensure_path_exists(path: str) -> None:
|
||||||
try:
|
try:
|
||||||
os.makedirs(path)
|
os.makedirs(path)
|
||||||
except OSError as exc:
|
except OSError as exc:
|
||||||
|
|
@ -1119,7 +1130,7 @@ def add_ago_to_since(since: str) -> str:
|
||||||
return since
|
return since
|
||||||
|
|
||||||
|
|
||||||
def convert_legacy_filters_into_adhoc(fd):
|
def convert_legacy_filters_into_adhoc(fd: FormData) -> None:
|
||||||
mapping = {"having": "having_filters", "where": "filters"}
|
mapping = {"having": "having_filters", "where": "filters"}
|
||||||
|
|
||||||
if not fd.get("adhoc_filters"):
|
if not fd.get("adhoc_filters"):
|
||||||
|
|
@ -1138,7 +1149,7 @@ def convert_legacy_filters_into_adhoc(fd):
|
||||||
del fd[key]
|
del fd[key]
|
||||||
|
|
||||||
|
|
||||||
def split_adhoc_filters_into_base_filters(fd):
|
def split_adhoc_filters_into_base_filters(fd: FormData) -> None:
|
||||||
"""
|
"""
|
||||||
Mutates form data to restructure the adhoc filters in the form of the four base
|
Mutates form data to restructure the adhoc filters in the form of the four base
|
||||||
filters, `where`, `having`, `filters`, and `having_filters` which represent
|
filters, `where`, `having`, `filters`, and `having_filters` which represent
|
||||||
|
|
@ -1230,7 +1241,7 @@ def create_ssl_cert_file(certificate: str) -> str:
|
||||||
return path
|
return path
|
||||||
|
|
||||||
|
|
||||||
def time_function(func: Callable, *args, **kwargs) -> Tuple[float, Any]:
|
def time_function(func: Callable, *args: Any, **kwargs: Any) -> Tuple[float, Any]:
|
||||||
"""
|
"""
|
||||||
Measures the amount of time a function takes to execute in ms
|
Measures the amount of time a function takes to execute in ms
|
||||||
|
|
||||||
|
|
@ -1296,7 +1307,7 @@ def split(
|
||||||
yield s[i:]
|
yield s[i:]
|
||||||
|
|
||||||
|
|
||||||
def get_iterable(x: Any) -> List:
|
def get_iterable(x: Any) -> List[Any]:
|
||||||
"""
|
"""
|
||||||
Get an iterable (list) representation of the object.
|
Get an iterable (list) representation of the object.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -17,14 +17,16 @@
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Dict, List
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
from superset.models.slice import Slice
|
from superset.models.slice import Slice
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def convert_filter_scopes(json_metadata: Dict, filters: List[Slice]):
|
def convert_filter_scopes(
|
||||||
|
json_metadata: Dict[Any, Any], filters: List[Slice]
|
||||||
|
) -> Dict[int, Dict[str, Dict[str, Any]]]:
|
||||||
filter_scopes = {}
|
filter_scopes = {}
|
||||||
immuned_by_id: List[int] = json_metadata.get("filter_immune_slices") or []
|
immuned_by_id: List[int] = json_metadata.get("filter_immune_slices") or []
|
||||||
immuned_by_column: Dict = defaultdict(list)
|
immuned_by_column: Dict = defaultdict(list)
|
||||||
|
|
@ -34,7 +36,9 @@ def convert_filter_scopes(json_metadata: Dict, filters: List[Slice]):
|
||||||
for column in columns:
|
for column in columns:
|
||||||
immuned_by_column[column].append(int(slice_id))
|
immuned_by_column[column].append(int(slice_id))
|
||||||
|
|
||||||
def add_filter_scope(filter_field, filter_id):
|
def add_filter_scope(
|
||||||
|
filter_fields: Dict[str, Dict[str, Any]], filter_field: str, filter_id: int
|
||||||
|
) -> None:
|
||||||
# in case filter field is invalid
|
# in case filter field is invalid
|
||||||
if isinstance(filter_field, str):
|
if isinstance(filter_field, str):
|
||||||
current_filter_immune = list(
|
current_filter_immune = list(
|
||||||
|
|
@ -54,17 +58,17 @@ def convert_filter_scopes(json_metadata: Dict, filters: List[Slice]):
|
||||||
configs = slice_params.get("filter_configs") or []
|
configs = slice_params.get("filter_configs") or []
|
||||||
|
|
||||||
if slice_params.get("date_filter"):
|
if slice_params.get("date_filter"):
|
||||||
add_filter_scope("__time_range", filter_id)
|
add_filter_scope(filter_fields, "__time_range", filter_id)
|
||||||
if slice_params.get("show_sqla_time_column"):
|
if slice_params.get("show_sqla_time_column"):
|
||||||
add_filter_scope("__time_col", filter_id)
|
add_filter_scope(filter_fields, "__time_col", filter_id)
|
||||||
if slice_params.get("show_sqla_time_granularity"):
|
if slice_params.get("show_sqla_time_granularity"):
|
||||||
add_filter_scope("__time_grain", filter_id)
|
add_filter_scope(filter_fields, "__time_grain", filter_id)
|
||||||
if slice_params.get("show_druid_time_granularity"):
|
if slice_params.get("show_druid_time_granularity"):
|
||||||
add_filter_scope("__granularity", filter_id)
|
add_filter_scope(filter_fields, "__granularity", filter_id)
|
||||||
if slice_params.get("show_druid_time_origin"):
|
if slice_params.get("show_druid_time_origin"):
|
||||||
add_filter_scope("druid_time_origin", filter_id)
|
add_filter_scope(filter_fields, "druid_time_origin", filter_id)
|
||||||
for config in configs:
|
for config in configs:
|
||||||
add_filter_scope(config.get("column"), filter_id)
|
add_filter_scope(filter_fields, config.get("column"), filter_id)
|
||||||
|
|
||||||
if filter_fields:
|
if filter_fields:
|
||||||
filter_scopes[filter_id] = filter_fields
|
filter_scopes[filter_id] = filter_fields
|
||||||
|
|
|
||||||
|
|
@ -19,6 +19,10 @@ import json
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from io import BytesIO
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
|
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
|
||||||
from superset.models.dashboard import Dashboard
|
from superset.models.dashboard import Dashboard
|
||||||
|
|
@ -27,7 +31,7 @@ from superset.models.slice import Slice
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def decode_dashboards(o):
|
def decode_dashboards(o: Dict[str, Any]) -> Any:
|
||||||
"""
|
"""
|
||||||
Function to be passed into json.loads obj_hook parameter
|
Function to be passed into json.loads obj_hook parameter
|
||||||
Recreates the dashboard object from a json representation.
|
Recreates the dashboard object from a json representation.
|
||||||
|
|
@ -50,7 +54,9 @@ def decode_dashboards(o):
|
||||||
return o
|
return o
|
||||||
|
|
||||||
|
|
||||||
def import_dashboards(session, data_stream, import_time=None):
|
def import_dashboards(
|
||||||
|
session: Session, data_stream: BytesIO, import_time: Optional[int] = None
|
||||||
|
) -> None:
|
||||||
"""Imports dashboards from a stream to databases"""
|
"""Imports dashboards from a stream to databases"""
|
||||||
current_tt = int(time.time())
|
current_tt = int(time.time())
|
||||||
import_time = current_tt if import_time is None else import_time
|
import_time = current_tt if import_time is None else import_time
|
||||||
|
|
@ -64,7 +70,7 @@ def import_dashboards(session, data_stream, import_time=None):
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
|
|
||||||
def export_dashboards(session):
|
def export_dashboards(session: Session) -> str:
|
||||||
"""Returns all dashboards metadata as a json dump"""
|
"""Returns all dashboards metadata as a json dump"""
|
||||||
logger.info("Starting export")
|
logger.info("Starting export")
|
||||||
dashboards = session.query(Dashboard)
|
dashboards = session.query(Dashboard)
|
||||||
|
|
|
||||||
|
|
@ -21,7 +21,7 @@ import pytz
|
||||||
EPOCH = datetime(1970, 1, 1)
|
EPOCH = datetime(1970, 1, 1)
|
||||||
|
|
||||||
|
|
||||||
def datetime_to_epoch(dttm):
|
def datetime_to_epoch(dttm: datetime) -> float:
|
||||||
if dttm.tzinfo:
|
if dttm.tzinfo:
|
||||||
dttm = dttm.replace(tzinfo=pytz.utc)
|
dttm = dttm.replace(tzinfo=pytz.utc)
|
||||||
epoch_with_tz = pytz.utc.localize(EPOCH)
|
epoch_with_tz = pytz.utc.localize(EPOCH)
|
||||||
|
|
@ -29,5 +29,5 @@ def datetime_to_epoch(dttm):
|
||||||
return (dttm - EPOCH).total_seconds() * 1000
|
return (dttm - EPOCH).total_seconds() * 1000
|
||||||
|
|
||||||
|
|
||||||
def now_as_float():
|
def now_as_float() -> float:
|
||||||
return datetime_to_epoch(datetime.utcnow())
|
return datetime_to_epoch(datetime.utcnow())
|
||||||
|
|
|
||||||
|
|
@ -17,11 +17,14 @@
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
from typing import Any, Callable, Iterator
|
||||||
|
|
||||||
from contextlib2 import contextmanager
|
from contextlib2 import contextmanager
|
||||||
from flask import request
|
from flask import request
|
||||||
|
from werkzeug.wrappers.etag import ETagResponseMixin
|
||||||
|
|
||||||
from superset import app, cache
|
from superset import app, cache
|
||||||
|
from superset.stats_logger import BaseStatsLogger
|
||||||
from superset.utils.dates import now_as_float
|
from superset.utils.dates import now_as_float
|
||||||
|
|
||||||
# If a user sets `max_age` to 0, for long the browser should cache the
|
# If a user sets `max_age` to 0, for long the browser should cache the
|
||||||
|
|
@ -32,7 +35,7 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def stats_timing(stats_key, stats_logger):
|
def stats_timing(stats_key: str, stats_logger: BaseStatsLogger) -> Iterator[float]:
|
||||||
"""Provide a transactional scope around a series of operations."""
|
"""Provide a transactional scope around a series of operations."""
|
||||||
start_ts = now_as_float()
|
start_ts = now_as_float()
|
||||||
try:
|
try:
|
||||||
|
|
@ -43,7 +46,7 @@ def stats_timing(stats_key, stats_logger):
|
||||||
stats_logger.timing(stats_key, now_as_float() - start_ts)
|
stats_logger.timing(stats_key, now_as_float() - start_ts)
|
||||||
|
|
||||||
|
|
||||||
def etag_cache(max_age, check_perms=bool):
|
def etag_cache(max_age: int, check_perms: Callable) -> Callable:
|
||||||
"""
|
"""
|
||||||
A decorator for caching views and handling etag conditional requests.
|
A decorator for caching views and handling etag conditional requests.
|
||||||
|
|
||||||
|
|
@ -57,9 +60,9 @@ def etag_cache(max_age, check_perms=bool):
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(f):
|
def decorator(f: Callable) -> Callable:
|
||||||
@wraps(f)
|
@wraps(f)
|
||||||
def wrapper(*args, **kwargs):
|
def wrapper(*args: Any, **kwargs: Any) -> ETagResponseMixin:
|
||||||
# check if the user can access the resource
|
# check if the user can access the resource
|
||||||
check_perms(*args, **kwargs)
|
check_perms(*args, **kwargs)
|
||||||
|
|
||||||
|
|
@ -77,7 +80,9 @@ def etag_cache(max_age, check_perms=bool):
|
||||||
key_args = list(args)
|
key_args = list(args)
|
||||||
key_kwargs = kwargs.copy()
|
key_kwargs = kwargs.copy()
|
||||||
key_kwargs.update(request.args)
|
key_kwargs.update(request.args)
|
||||||
cache_key = wrapper.make_cache_key(f, *key_args, **key_kwargs)
|
cache_key = wrapper.make_cache_key( # type: ignore
|
||||||
|
f, *key_args, **key_kwargs
|
||||||
|
)
|
||||||
response = cache.get(cache_key)
|
response = cache.get(cache_key)
|
||||||
except Exception: # pylint: disable=broad-except
|
except Exception: # pylint: disable=broad-except
|
||||||
if app.debug:
|
if app.debug:
|
||||||
|
|
@ -109,9 +114,9 @@ def etag_cache(max_age, check_perms=bool):
|
||||||
return response.make_conditional(request)
|
return response.make_conditional(request)
|
||||||
|
|
||||||
if cache:
|
if cache:
|
||||||
wrapper.uncached = f
|
wrapper.uncached = f # type: ignore
|
||||||
wrapper.cache_timeout = max_age
|
wrapper.cache_timeout = max_age # type: ignore
|
||||||
wrapper.make_cache_key = cache._memoize_make_cache_key( # pylint: disable=protected-access
|
wrapper.make_cache_key = cache._memoize_make_cache_key( # type: ignore # pylint: disable=protected-access
|
||||||
make_name=None, timeout=max_age
|
make_name=None, timeout=max_age
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,9 @@
|
||||||
# under the License.
|
# under the License.
|
||||||
# pylint: disable=C,R,W
|
# pylint: disable=C,R,W
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from superset.connectors.druid.models import DruidCluster
|
from superset.connectors.druid.models import DruidCluster
|
||||||
from superset.models.core import Database
|
from superset.models.core import Database
|
||||||
|
|
@ -25,7 +28,7 @@ DRUID_CLUSTERS_KEY = "druid_clusters"
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def export_schema_to_dict(back_references):
|
def export_schema_to_dict(back_references: bool) -> Dict[str, Any]:
|
||||||
"""Exports the supported import/export schema to a dictionary"""
|
"""Exports the supported import/export schema to a dictionary"""
|
||||||
databases = [
|
databases = [
|
||||||
Database.export_schema(recursive=True, include_parent_ref=back_references)
|
Database.export_schema(recursive=True, include_parent_ref=back_references)
|
||||||
|
|
@ -41,7 +44,9 @@ def export_schema_to_dict(back_references):
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
def export_to_dict(session, recursive, back_references, include_defaults):
|
def export_to_dict(
|
||||||
|
session: Session, recursive: bool, back_references: bool, include_defaults: bool
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""Exports databases and druid clusters to a dictionary"""
|
"""Exports databases and druid clusters to a dictionary"""
|
||||||
logger.info("Starting export")
|
logger.info("Starting export")
|
||||||
dbs = session.query(Database)
|
dbs = session.query(Database)
|
||||||
|
|
@ -72,8 +77,12 @@ def export_to_dict(session, recursive, back_references, include_defaults):
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
def import_from_dict(session, data, sync=[]):
|
def import_from_dict(
|
||||||
|
session: Session, data: Dict[str, Any], sync: Optional[List[str]] = None
|
||||||
|
) -> None:
|
||||||
"""Imports databases and druid clusters from dictionary"""
|
"""Imports databases and druid clusters from dictionary"""
|
||||||
|
if not sync:
|
||||||
|
sync = []
|
||||||
if isinstance(data, dict):
|
if isinstance(data, dict):
|
||||||
logger.info("Importing %d %s", len(data.get(DATABASES_KEY, [])), DATABASES_KEY)
|
logger.info("Importing %d %s", len(data.get(DATABASES_KEY, [])), DATABASES_KEY)
|
||||||
for database in data.get(DATABASES_KEY, []):
|
for database in data.get(DATABASES_KEY, []):
|
||||||
|
|
|
||||||
|
|
@ -15,25 +15,33 @@
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
from flask import Flask
|
||||||
|
|
||||||
|
|
||||||
class FeatureFlagManager:
|
class FeatureFlagManager:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._get_feature_flags_func = None
|
self._get_feature_flags_func = None
|
||||||
self._feature_flags = None
|
self._feature_flags: Dict[str, Any] = {}
|
||||||
|
|
||||||
def init_app(self, app):
|
def init_app(self, app: Flask) -> None:
|
||||||
self._get_feature_flags_func = app.config.get("GET_FEATURE_FLAGS_FUNC")
|
self._get_feature_flags_func = app.config["GET_FEATURE_FLAGS_FUNC"]
|
||||||
self._feature_flags = app.config.get("DEFAULT_FEATURE_FLAGS") or {}
|
self._feature_flags = app.config["DEFAULT_FEATURE_FLAGS"]
|
||||||
self._feature_flags.update(app.config.get("FEATURE_FLAGS") or {})
|
self._feature_flags.update(app.config["FEATURE_FLAGS"])
|
||||||
|
|
||||||
def get_feature_flags(self):
|
def get_feature_flags(self) -> Dict[str, Any]:
|
||||||
if self._get_feature_flags_func:
|
if self._get_feature_flags_func:
|
||||||
return self._get_feature_flags_func(deepcopy(self._feature_flags))
|
return self._get_feature_flags_func(deepcopy(self._feature_flags))
|
||||||
|
|
||||||
return self._feature_flags
|
return self._feature_flags
|
||||||
|
|
||||||
def is_feature_enabled(self, feature) -> bool:
|
def is_feature_enabled(self, feature: str) -> bool:
|
||||||
"""Utility function for checking whether a feature is turned on"""
|
"""Utility function for checking whether a feature is turned on"""
|
||||||
return self.get_feature_flags().get(feature)
|
feature_flags = self.get_feature_flags()
|
||||||
|
|
||||||
|
if feature_flags and feature in feature_flags:
|
||||||
|
return feature_flags[feature]
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
|
||||||
|
|
@ -21,19 +21,23 @@ import logging
|
||||||
import textwrap
|
import textwrap
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, cast, Type
|
from typing import Any, Callable, cast, Optional, Type
|
||||||
|
|
||||||
from flask import current_app, g, request
|
from flask import current_app, g, request
|
||||||
|
|
||||||
|
from superset.stats_logger import BaseStatsLogger
|
||||||
|
|
||||||
|
|
||||||
class AbstractEventLogger(ABC):
|
class AbstractEventLogger(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def log(self, user_id, action, *args, **kwargs):
|
def log(
|
||||||
|
self, user_id: Optional[int], action: str, *args: Any, **kwargs: Any
|
||||||
|
) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def log_this(self, f):
|
def log_this(self, f: Callable) -> Callable:
|
||||||
@functools.wraps(f)
|
@functools.wraps(f)
|
||||||
def wrapper(*args, **kwargs):
|
def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||||
user_id = None
|
user_id = None
|
||||||
if g.user:
|
if g.user:
|
||||||
user_id = g.user.get_id()
|
user_id = g.user.get_id()
|
||||||
|
|
@ -49,7 +53,12 @@ class AbstractEventLogger(ABC):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
slice_id = int(
|
slice_id = int(
|
||||||
slice_id or json.loads(form_data.get("form_data")).get("slice_id")
|
slice_id
|
||||||
|
or json.loads(
|
||||||
|
form_data.get("form_data") # type: ignore
|
||||||
|
).get(
|
||||||
|
"slice_id"
|
||||||
|
)
|
||||||
)
|
)
|
||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
slice_id = 0
|
slice_id = 0
|
||||||
|
|
@ -62,7 +71,7 @@ class AbstractEventLogger(ABC):
|
||||||
# bulk insert
|
# bulk insert
|
||||||
try:
|
try:
|
||||||
explode_by = form_data.get("explode")
|
explode_by = form_data.get("explode")
|
||||||
records = json.loads(form_data.get(explode_by))
|
records = json.loads(form_data.get(explode_by)) # type: ignore
|
||||||
except Exception: # pylint: disable=broad-except
|
except Exception: # pylint: disable=broad-except
|
||||||
records = [form_data]
|
records = [form_data]
|
||||||
|
|
||||||
|
|
@ -82,11 +91,11 @@ class AbstractEventLogger(ABC):
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def stats_logger(self):
|
def stats_logger(self) -> BaseStatsLogger:
|
||||||
return current_app.config["STATS_LOGGER"]
|
return current_app.config["STATS_LOGGER"]
|
||||||
|
|
||||||
|
|
||||||
def get_event_logger_from_cfg_value(cfg_value: object) -> AbstractEventLogger:
|
def get_event_logger_from_cfg_value(cfg_value: Any) -> AbstractEventLogger:
|
||||||
"""
|
"""
|
||||||
This function implements the deprecation of assignment
|
This function implements the deprecation of assignment
|
||||||
of class objects to EVENT_LOGGER configuration, and validates
|
of class objects to EVENT_LOGGER configuration, and validates
|
||||||
|
|
@ -130,7 +139,9 @@ def get_event_logger_from_cfg_value(cfg_value: object) -> AbstractEventLogger:
|
||||||
|
|
||||||
|
|
||||||
class DBEventLogger(AbstractEventLogger):
|
class DBEventLogger(AbstractEventLogger):
|
||||||
def log(self, user_id, action, *args, **kwargs): # pylint: disable=too-many-locals
|
def log( # pylint: disable=too-many-locals
|
||||||
|
self, user_id: Optional[int], action: str, *args: Any, **kwargs: Any
|
||||||
|
) -> None:
|
||||||
from superset.models.core import Log
|
from superset.models.core import Log
|
||||||
|
|
||||||
records = kwargs.get("records", list())
|
records = kwargs.get("records", list())
|
||||||
|
|
@ -141,6 +152,7 @@ class DBEventLogger(AbstractEventLogger):
|
||||||
|
|
||||||
logs = list()
|
logs = list()
|
||||||
for record in records:
|
for record in records:
|
||||||
|
json_string: Optional[str]
|
||||||
try:
|
try:
|
||||||
json_string = json.dumps(record)
|
json_string = json.dumps(record)
|
||||||
except Exception: # pylint: disable=broad-except
|
except Exception: # pylint: disable=broad-except
|
||||||
|
|
|
||||||
|
|
@ -73,8 +73,8 @@ WHITELIST_CUMULATIVE_FUNCTIONS = (
|
||||||
|
|
||||||
|
|
||||||
def validate_column_args(*argnames: str) -> Callable:
|
def validate_column_args(*argnames: str) -> Callable:
|
||||||
def wrapper(func):
|
def wrapper(func: Callable) -> Callable:
|
||||||
def wrapped(df, **options):
|
def wrapped(df: DataFrame, **options: Any) -> Any:
|
||||||
columns = df.columns.tolist()
|
columns = df.columns.tolist()
|
||||||
for name in argnames:
|
for name in argnames:
|
||||||
if name in options and not all(
|
if name in options and not all(
|
||||||
|
|
@ -159,7 +159,7 @@ def pivot( # pylint: disable=too-many-arguments
|
||||||
metric_fill_value: Optional[Any] = None,
|
metric_fill_value: Optional[Any] = None,
|
||||||
column_fill_value: Optional[str] = None,
|
column_fill_value: Optional[str] = None,
|
||||||
drop_missing_columns: Optional[bool] = True,
|
drop_missing_columns: Optional[bool] = True,
|
||||||
combine_value_with_metric=False,
|
combine_value_with_metric: bool = False,
|
||||||
marginal_distributions: Optional[bool] = None,
|
marginal_distributions: Optional[bool] = None,
|
||||||
marginal_distribution_name: Optional[str] = None,
|
marginal_distribution_name: Optional[str] = None,
|
||||||
) -> DataFrame:
|
) -> DataFrame:
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,7 @@ import logging
|
||||||
import time
|
import time
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from typing import Callable, Dict, List, Optional, Tuple, TYPE_CHECKING
|
from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||||
|
|
||||||
from flask import current_app, request, Response, session, url_for
|
from flask import current_app, request, Response, session, url_for
|
||||||
from flask_login import login_user
|
from flask_login import login_user
|
||||||
|
|
@ -91,7 +91,7 @@ def headless_url(path: str) -> str:
|
||||||
return urllib.parse.urljoin(current_app.config.get("WEBDRIVER_BASEURL", ""), path)
|
return urllib.parse.urljoin(current_app.config.get("WEBDRIVER_BASEURL", ""), path)
|
||||||
|
|
||||||
|
|
||||||
def get_url_path(view: str, **kwargs) -> str:
|
def get_url_path(view: str, **kwargs: Any) -> str:
|
||||||
with current_app.test_request_context():
|
with current_app.test_request_context():
|
||||||
return headless_url(url_for(view, **kwargs))
|
return headless_url(url_for(view, **kwargs))
|
||||||
|
|
||||||
|
|
@ -135,7 +135,7 @@ class AuthWebDriverProxy:
|
||||||
return self._auth_func(driver, user)
|
return self._auth_func(driver, user)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def destroy(driver: WebDriver, tries=2):
|
def destroy(driver: WebDriver, tries: int = 2) -> None:
|
||||||
"""Destroy a driver"""
|
"""Destroy a driver"""
|
||||||
# This is some very flaky code in selenium. Hence the retries
|
# This is some very flaky code in selenium. Hence the retries
|
||||||
# and catch-all exceptions
|
# and catch-all exceptions
|
||||||
|
|
|
||||||
|
|
@ -14,22 +14,24 @@
|
||||||
# KIND, either express or implied. See the License for the
|
# KIND, either express or implied. See the License for the
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
from werkzeug.routing import BaseConverter
|
from typing import Any, List
|
||||||
|
|
||||||
|
from werkzeug.routing import BaseConverter, Map
|
||||||
|
|
||||||
from superset.models.tags import ObjectTypes
|
from superset.models.tags import ObjectTypes
|
||||||
|
|
||||||
|
|
||||||
class RegexConverter(BaseConverter):
|
class RegexConverter(BaseConverter):
|
||||||
def __init__(self, url_map, *items):
|
def __init__(self, url_map: Map, *items: List[str]) -> None:
|
||||||
super(RegexConverter, self).__init__(url_map)
|
super(RegexConverter, self).__init__(url_map) # type: ignore
|
||||||
self.regex = items[0]
|
self.regex = items[0]
|
||||||
|
|
||||||
|
|
||||||
class ObjectTypeConverter(BaseConverter):
|
class ObjectTypeConverter(BaseConverter):
|
||||||
"""Validate that object_type is indeed an object type."""
|
"""Validate that object_type is indeed an object type."""
|
||||||
|
|
||||||
def to_python(self, value):
|
def to_python(self, value: str) -> Any:
|
||||||
return ObjectTypes[value]
|
return ObjectTypes[value]
|
||||||
|
|
||||||
def to_url(self, value):
|
def to_url(self, value: Any) -> str:
|
||||||
return value.name
|
return value.name
|
||||||
|
|
|
||||||
|
|
@ -2164,7 +2164,7 @@ class Superset(BaseSupersetView):
|
||||||
return json_error_response(str(ex))
|
return json_error_response(str(ex))
|
||||||
|
|
||||||
spec = mydb.db_engine_spec
|
spec = mydb.db_engine_spec
|
||||||
query_cost_formatters = get_feature_flags().get(
|
query_cost_formatters: Dict[str, Any] = get_feature_flags().get(
|
||||||
"QUERY_COST_FORMATTERS_BY_ENGINE", {}
|
"QUERY_COST_FORMATTERS_BY_ENGINE", {}
|
||||||
)
|
)
|
||||||
query_cost_formatter = query_cost_formatters.get(
|
query_cost_formatter = query_cost_formatters.get(
|
||||||
|
|
|
||||||
|
|
@ -38,7 +38,6 @@ from superset.utils.core import (
|
||||||
base_json_conv,
|
base_json_conv,
|
||||||
convert_legacy_filters_into_adhoc,
|
convert_legacy_filters_into_adhoc,
|
||||||
create_ssl_cert_file,
|
create_ssl_cert_file,
|
||||||
datetime_f,
|
|
||||||
format_timedelta,
|
format_timedelta,
|
||||||
get_iterable,
|
get_iterable,
|
||||||
get_email_address_list,
|
get_email_address_list,
|
||||||
|
|
@ -560,17 +559,6 @@ class UtilsTestCase(SupersetTestCase):
|
||||||
url_params["dashboard_ids"], form_data["url_params"]["dashboard_ids"]
|
url_params["dashboard_ids"], form_data["url_params"]["dashboard_ids"]
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_datetime_f(self):
|
|
||||||
self.assertEqual(
|
|
||||||
datetime_f(datetime(1990, 9, 21, 19, 11, 19, 626096)),
|
|
||||||
"<nobr>1990-09-21T19:11:19.626096</nobr>",
|
|
||||||
)
|
|
||||||
self.assertEqual(len(datetime_f(datetime.now())), 28)
|
|
||||||
self.assertEqual(datetime_f(None), "<nobr>None</nobr>")
|
|
||||||
iso = datetime.now().isoformat()[:10].split("-")
|
|
||||||
[a, b, c] = [int(v) for v in iso]
|
|
||||||
self.assertEqual(datetime_f(datetime(a, b, c)), "<nobr>00:00:00</nobr>")
|
|
||||||
|
|
||||||
def test_format_timedelta(self):
|
def test_format_timedelta(self):
|
||||||
self.assertEqual(format_timedelta(timedelta(0)), "0:00:00")
|
self.assertEqual(format_timedelta(timedelta(0)), "0:00:00")
|
||||||
self.assertEqual(format_timedelta(timedelta(days=1)), "1 day, 0:00:00")
|
self.assertEqual(format_timedelta(timedelta(days=1)), "1 day, 0:00:00")
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue