style(mypy): Spit-and-polish pass (#10001)

Co-authored-by: John Bodley <john.bodley@airbnb.com>
This commit is contained in:
John Bodley 2020-06-07 08:53:46 -07:00 committed by GitHub
parent 656cdfb867
commit 91517a56a3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
56 changed files with 243 additions and 207 deletions

View File

@ -50,10 +50,12 @@ multi_line_output = 3
order_by_type = false order_by_type = false
[mypy] [mypy]
disallow_any_generics = true
ignore_missing_imports = true ignore_missing_imports = true
no_implicit_optional = true no_implicit_optional = true
warn_unused_ignores = true
[mypy-superset,superset.app,superset.bin.*,superset.charts.*,superset.cli,superset.commands.*,superset.common.*,superset.config,superset.connectors.*,superset.constants,superset.dataframe,superset.dao.*,superset.dashboards.*,superset.datasets.*,superset.db_engine_specs.*,superset.db_engines.*,superset.errors,superset.examples.*,superset.exceptions,superset.extensions,superset.forms,superset.jinja_context,superset.legacy,superset.migrations.*,superset.models.*,superset.result_set,superset.queries.*,superset.security.*,superset.sql_lab,superset.sql_parse,superset.sql_validators.*,superset.stats_logger,superset.tasks.*,superset.translations.*,superset.typing,superset.utils.*,,superset.views.*,superset.viz,superset.viz_sip38] [mypy-superset.*]
check_untyped_defs = true check_untyped_defs = true
disallow_untyped_calls = true disallow_untyped_calls = true
disallow_untyped_defs = true disallow_untyped_defs = true

View File

@ -80,7 +80,7 @@ class SupersetAppInitializer:
self.flask_app = app self.flask_app = app
self.config = app.config self.config = app.config
self.manifest: dict = {} self.manifest: Dict[Any, Any] = {}
def pre_init(self) -> None: def pre_init(self) -> None:
""" """
@ -542,7 +542,7 @@ class SupersetAppInitializer:
self.app = app self.app = app
def __call__( def __call__(
self, environ: Dict[str, Any], start_response: Callable self, environ: Dict[str, Any], start_response: Callable[..., Any]
) -> Any: ) -> Any:
# Setting wsgi.input_terminated tells werkzeug.wsgi to ignore # Setting wsgi.input_terminated tells werkzeug.wsgi to ignore
# content-length and read the stream till the end. # content-length and read the stream till the end.

View File

@ -15,7 +15,7 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import logging import logging
from typing import Dict, List, Optional from typing import Any, Dict, List, Optional
from flask_appbuilder.models.sqla import Model from flask_appbuilder.models.sqla import Model
from flask_appbuilder.security.sqla.models import User from flask_appbuilder.security.sqla.models import User
@ -36,7 +36,7 @@ logger = logging.getLogger(__name__)
class CreateChartCommand(BaseCommand): class CreateChartCommand(BaseCommand):
def __init__(self, user: User, data: Dict): def __init__(self, user: User, data: Dict[str, Any]):
self._actor = user self._actor = user
self._properties = data.copy() self._properties = data.copy()

View File

@ -15,7 +15,7 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import logging import logging
from typing import Dict, List, Optional from typing import Any, Dict, List, Optional
from flask_appbuilder.models.sqla import Model from flask_appbuilder.models.sqla import Model
from flask_appbuilder.security.sqla.models import User from flask_appbuilder.security.sqla.models import User
@ -42,7 +42,7 @@ logger = logging.getLogger(__name__)
class UpdateChartCommand(BaseCommand): class UpdateChartCommand(BaseCommand):
def __init__(self, user: User, model_id: int, data: Dict): def __init__(self, user: User, model_id: int, data: Dict[str, Any]):
self._actor = user self._actor = user
self._model_id = model_id self._model_id = model_id
self._properties = data.copy() self._properties = data.copy()

View File

@ -26,6 +26,7 @@ from pandas import DataFrame
from superset import app, is_feature_enabled from superset import app, is_feature_enabled
from superset.exceptions import QueryObjectValidationError from superset.exceptions import QueryObjectValidationError
from superset.typing import Metric
from superset.utils import core as utils, pandas_postprocessing from superset.utils import core as utils, pandas_postprocessing
from superset.views.utils import get_time_range_endpoints from superset.views.utils import get_time_range_endpoints
@ -67,11 +68,11 @@ class QueryObject:
row_limit: int row_limit: int
filter: List[Dict[str, Any]] filter: List[Dict[str, Any]]
timeseries_limit: int timeseries_limit: int
timeseries_limit_metric: Optional[Dict] timeseries_limit_metric: Optional[Metric]
order_desc: bool order_desc: bool
extras: Dict extras: Dict[str, Any]
columns: List[str] columns: List[str]
orderby: List[List] orderby: List[List[str]]
post_processing: List[Dict[str, Any]] post_processing: List[Dict[str, Any]]
def __init__( def __init__(
@ -85,11 +86,11 @@ class QueryObject:
is_timeseries: bool = False, is_timeseries: bool = False,
timeseries_limit: int = 0, timeseries_limit: int = 0,
row_limit: int = app.config["ROW_LIMIT"], row_limit: int = app.config["ROW_LIMIT"],
timeseries_limit_metric: Optional[Dict] = None, timeseries_limit_metric: Optional[Metric] = None,
order_desc: bool = True, order_desc: bool = True,
extras: Optional[Dict] = None, extras: Optional[Dict[str, Any]] = None,
columns: Optional[List[str]] = None, columns: Optional[List[str]] = None,
orderby: Optional[List[List]] = None, orderby: Optional[List[List[str]]] = None,
post_processing: Optional[List[Dict[str, Any]]] = None, post_processing: Optional[List[Dict[str, Any]]] = None,
**kwargs: Any, **kwargs: Any,
): ):

View File

@ -33,6 +33,7 @@ from typing import Any, Callable, Dict, List, Optional, Type, TYPE_CHECKING
from cachelib.base import BaseCache from cachelib.base import BaseCache
from celery.schedules import crontab from celery.schedules import crontab
from dateutil import tz from dateutil import tz
from flask import Blueprint
from flask_appbuilder.security.manager import AUTH_DB from flask_appbuilder.security.manager import AUTH_DB
from superset.jinja_context import ( # pylint: disable=unused-import from superset.jinja_context import ( # pylint: disable=unused-import
@ -421,7 +422,7 @@ DEFAULT_MODULE_DS_MAP = OrderedDict(
] ]
) )
ADDITIONAL_MODULE_DS_MAP: Dict[str, List[str]] = {} ADDITIONAL_MODULE_DS_MAP: Dict[str, List[str]] = {}
ADDITIONAL_MIDDLEWARE: List[Callable] = [] ADDITIONAL_MIDDLEWARE: List[Callable[..., Any]] = []
# 1) https://docs.python-guide.org/writing/logging/ # 1) https://docs.python-guide.org/writing/logging/
# 2) https://docs.python.org/2/library/logging.config.html # 2) https://docs.python.org/2/library/logging.config.html
@ -624,7 +625,7 @@ ALLOWED_USER_CSV_SCHEMA_FUNC: Callable[
# SQL Lab. The existing context gets updated with this dictionary, # SQL Lab. The existing context gets updated with this dictionary,
# meaning values for existing keys get overwritten by the content of this # meaning values for existing keys get overwritten by the content of this
# dictionary. # dictionary.
JINJA_CONTEXT_ADDONS: Dict[str, Callable] = {} JINJA_CONTEXT_ADDONS: Dict[str, Callable[..., Any]] = {}
# A dictionary of macro template processors that gets merged into global # A dictionary of macro template processors that gets merged into global
# template processors. The existing template processors get updated with this # template processors. The existing template processors get updated with this
@ -684,7 +685,7 @@ PERMISSION_INSTRUCTIONS_LINK = ""
# Integrate external Blueprints to the app by passing them to your # Integrate external Blueprints to the app by passing them to your
# configuration. These blueprints will get integrated in the app # configuration. These blueprints will get integrated in the app
BLUEPRINTS: List[Callable] = [] BLUEPRINTS: List[Blueprint] = []
# Provide a callable that receives a tracking_url and returns another # Provide a callable that receives a tracking_url and returns another
# URL. This is used to translate internal Hadoop job tracker URL # URL. This is used to translate internal Hadoop job tracker URL

View File

@ -15,7 +15,7 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import json import json
from typing import Any, Dict, Hashable, List, Optional, Type from typing import Any, Dict, Hashable, List, Optional, Type, Union
from flask_appbuilder.security.sqla.models import User from flask_appbuilder.security.sqla.models import User
from sqlalchemy import and_, Boolean, Column, Integer, String, Text from sqlalchemy import and_, Boolean, Column, Integer, String, Text
@ -64,12 +64,12 @@ class BaseDatasource(
baselink: Optional[str] = None # url portion pointing to ModelView endpoint baselink: Optional[str] = None # url portion pointing to ModelView endpoint
@property @property
def column_class(self) -> Type: def column_class(self) -> Type["BaseColumn"]:
# link to derivative of BaseColumn # link to derivative of BaseColumn
raise NotImplementedError() raise NotImplementedError()
@property @property
def metric_class(self) -> Type: def metric_class(self) -> Type["BaseMetric"]:
# link to derivative of BaseMetric # link to derivative of BaseMetric
raise NotImplementedError() raise NotImplementedError()
@ -368,7 +368,7 @@ class BaseDatasource(
""" """
raise NotImplementedError() raise NotImplementedError()
def values_for_column(self, column_name: str, limit: int = 10000) -> List: def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]:
"""Given a column, returns an iterable of distinct values """Given a column, returns an iterable of distinct values
This is used to populate the dropdown showing a list of This is used to populate the dropdown showing a list of
@ -389,7 +389,10 @@ class BaseDatasource(
@staticmethod @staticmethod
def get_fk_many_from_list( def get_fk_many_from_list(
object_list: List[Any], fkmany: List[Column], fkmany_class: Type, key_attr: str, object_list: List[Any],
fkmany: List[Column],
fkmany_class: Type[Union["BaseColumn", "BaseMetric"]],
key_attr: str,
) -> List[Column]: # pylint: disable=too-many-locals ) -> List[Column]: # pylint: disable=too-many-locals
"""Update ORM one-to-many list from object list """Update ORM one-to-many list from object list

View File

@ -14,7 +14,6 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
from collections import OrderedDict
from typing import Dict, List, Optional, Set, Type, TYPE_CHECKING from typing import Dict, List, Optional, Set, Type, TYPE_CHECKING
from sqlalchemy import or_ from sqlalchemy import or_
@ -22,6 +21,8 @@ from sqlalchemy.orm import Session, subqueryload
if TYPE_CHECKING: if TYPE_CHECKING:
# pylint: disable=unused-import # pylint: disable=unused-import
from collections import OrderedDict
from superset.models.core import Database from superset.models.core import Database
from superset.connectors.base.models import BaseDatasource from superset.connectors.base.models import BaseDatasource
@ -32,7 +33,7 @@ class ConnectorRegistry:
sources: Dict[str, Type["BaseDatasource"]] = {} sources: Dict[str, Type["BaseDatasource"]] = {}
@classmethod @classmethod
def register_sources(cls, datasource_config: OrderedDict) -> None: def register_sources(cls, datasource_config: "OrderedDict[str, List[str]]") -> None:
for module_name, class_names in datasource_config.items(): for module_name, class_names in datasource_config.items():
class_names = [str(s) for s in class_names] class_names = [str(s) for s in class_names]
module_obj = __import__(module_name, fromlist=class_names) module_obj = __import__(module_name, fromlist=class_names)

View File

@ -24,18 +24,7 @@ from copy import deepcopy
from datetime import datetime, timedelta from datetime import datetime, timedelta
from distutils.version import LooseVersion from distutils.version import LooseVersion
from multiprocessing.pool import ThreadPool from multiprocessing.pool import ThreadPool
from typing import ( from typing import Any, cast, Dict, Iterable, List, Optional, Set, Tuple, Union
Any,
Callable,
cast,
Dict,
Iterable,
List,
Optional,
Set,
Tuple,
Union,
)
import pandas as pd import pandas as pd
import sqlalchemy as sa import sqlalchemy as sa
@ -173,7 +162,7 @@ class DruidCluster(Model, AuditMixinNullable, ImportMixin):
return self.__repr__() return self.__repr__()
@property @property
def data(self) -> Dict: def data(self) -> Dict[str, Any]:
return {"id": self.id, "name": self.cluster_name, "backend": "druid"} return {"id": self.id, "name": self.cluster_name, "backend": "druid"}
@staticmethod @staticmethod
@ -354,7 +343,7 @@ class DruidColumn(Model, BaseColumn):
return self.dimension_spec_json return self.dimension_spec_json
@property @property
def dimension_spec(self) -> Optional[Dict]: def dimension_spec(self) -> Optional[Dict[str, Any]]:
if self.dimension_spec_json: if self.dimension_spec_json:
return json.loads(self.dimension_spec_json) return json.loads(self.dimension_spec_json)
return None return None
@ -438,7 +427,7 @@ class DruidMetric(Model, BaseMetric):
return self.json return self.json
@property @property
def json_obj(self) -> Dict: def json_obj(self) -> Dict[str, Any]:
try: try:
obj = json.loads(self.json) obj = json.loads(self.json)
except Exception: except Exception:
@ -614,7 +603,7 @@ class DruidDatasource(Model, BaseDatasource):
name = escape(self.datasource_name) name = escape(self.datasource_name)
return Markup(f'<a href="{url}">{name}</a>') return Markup(f'<a href="{url}">{name}</a>')
def get_metric_obj(self, metric_name: str) -> Dict: def get_metric_obj(self, metric_name: str) -> Dict[str, Any]:
return [m.json_obj for m in self.metrics if m.metric_name == metric_name][0] return [m.json_obj for m in self.metrics if m.metric_name == metric_name][0]
@classmethod @classmethod
@ -705,7 +694,11 @@ class DruidDatasource(Model, BaseDatasource):
@classmethod @classmethod
def sync_to_db_from_config( def sync_to_db_from_config(
cls, druid_config: Dict, user: User, cluster: DruidCluster, refresh: bool = True cls,
druid_config: Dict[str, Any],
user: User,
cluster: DruidCluster,
refresh: bool = True,
) -> None: ) -> None:
"""Merges the ds config from druid_config into one stored in the db.""" """Merges the ds config from druid_config into one stored in the db."""
session = db.session session = db.session
@ -901,7 +894,7 @@ class DruidDatasource(Model, BaseDatasource):
return postagg_metrics return postagg_metrics
@staticmethod @staticmethod
def recursive_get_fields(_conf: Dict) -> List[str]: def recursive_get_fields(_conf: Dict[str, Any]) -> List[str]:
_type = _conf.get("type") _type = _conf.get("type")
_field = _conf.get("field") _field = _conf.get("field")
_fields = _conf.get("fields") _fields = _conf.get("fields")
@ -957,8 +950,8 @@ class DruidDatasource(Model, BaseDatasource):
@staticmethod @staticmethod
def metrics_and_post_aggs( def metrics_and_post_aggs(
metrics: List[Union[Dict, str]], metrics_dict: Dict[str, DruidMetric], metrics: List[Metric], metrics_dict: Dict[str, DruidMetric],
) -> Tuple[OrderedDict, OrderedDict]: ) -> Tuple["OrderedDict[str, Any]", "OrderedDict[str, Any]"]:
# Separate metrics into those that are aggregations # Separate metrics into those that are aggregations
# and those that are post aggregations # and those that are post aggregations
saved_agg_names = set() saved_agg_names = set()
@ -987,7 +980,7 @@ class DruidDatasource(Model, BaseDatasource):
) )
return aggs, post_aggs return aggs, post_aggs
def values_for_column(self, column_name: str, limit: int = 10000) -> List: def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]:
"""Retrieve some values for the given column""" """Retrieve some values for the given column"""
logger.info( logger.info(
"Getting values for columns [{}] limited to [{}]".format(column_name, limit) "Getting values for columns [{}] limited to [{}]".format(column_name, limit)
@ -1079,8 +1072,10 @@ class DruidDatasource(Model, BaseDatasource):
@staticmethod @staticmethod
def get_aggregations( def get_aggregations(
metrics_dict: Dict, saved_metrics: Set[str], adhoc_metrics: List[Dict] = [] metrics_dict: Dict[str, Any],
) -> OrderedDict: saved_metrics: Set[str],
adhoc_metrics: Optional[List[Dict[str, Any]]] = None,
) -> "OrderedDict[str, Any]":
""" """
Returns a dictionary of aggregation metric names to aggregation json objects Returns a dictionary of aggregation metric names to aggregation json objects
@ -1089,7 +1084,9 @@ class DruidDatasource(Model, BaseDatasource):
:param adhoc_metrics: list of adhoc metric names :param adhoc_metrics: list of adhoc metric names
:raise SupersetException: if one or more metric names are not aggregations :raise SupersetException: if one or more metric names are not aggregations
""" """
aggregations: OrderedDict = OrderedDict() if not adhoc_metrics:
adhoc_metrics = []
aggregations = OrderedDict()
invalid_metric_names = [] invalid_metric_names = []
for metric_name in saved_metrics: for metric_name in saved_metrics:
if metric_name in metrics_dict: if metric_name in metrics_dict:
@ -1115,7 +1112,7 @@ class DruidDatasource(Model, BaseDatasource):
def get_dimensions( def get_dimensions(
self, columns: List[str], columns_dict: Dict[str, DruidColumn] self, columns: List[str], columns_dict: Dict[str, DruidColumn]
) -> List[Union[str, Dict]]: ) -> List[Union[str, Dict[str, Any]]]:
dimensions = [] dimensions = []
columns = [col for col in columns if col in columns_dict] columns = [col for col in columns if col in columns_dict]
for column_name in columns: for column_name in columns:
@ -1433,7 +1430,7 @@ class DruidDatasource(Model, BaseDatasource):
df[columns] = df[columns].fillna(NULL_STRING).astype("unicode") df[columns] = df[columns].fillna(NULL_STRING).astype("unicode")
return df return df
def query(self, query_obj: Dict) -> QueryResult: def query(self, query_obj: QueryObjectDict) -> QueryResult:
qry_start_dttm = datetime.now() qry_start_dttm = datetime.now()
client = self.cluster.get_pydruid_client() client = self.cluster.get_pydruid_client()
query_str = self.get_query_str(client=client, query_obj=query_obj, phase=2) query_str = self.get_query_str(client=client, query_obj=query_obj, phase=2)
@ -1583,7 +1580,7 @@ class DruidDatasource(Model, BaseDatasource):
dimension=col, value=eq, extraction_function=extraction_fn dimension=col, value=eq, extraction_function=extraction_fn
) )
elif is_list_target: elif is_list_target:
eq = cast(list, eq) eq = cast(List[Any], eq)
fields = [] fields = []
# ignore the filter if it has no value # ignore the filter if it has no value
if not len(eq): if not len(eq):

View File

@ -597,7 +597,7 @@ class SqlaTable(Model, BaseDatasource):
) )
@property @property
def data(self) -> Dict: def data(self) -> Dict[str, Any]:
d = super().data d = super().data
if self.type == "table": if self.type == "table":
grains = self.database.grains() or [] grains = self.database.grains() or []
@ -684,7 +684,9 @@ class SqlaTable(Model, BaseDatasource):
return TextAsFrom(sa.text(from_sql), []).alias("expr_qry") return TextAsFrom(sa.text(from_sql), []).alias("expr_qry")
return self.get_sqla_table() return self.get_sqla_table()
def adhoc_metric_to_sqla(self, metric: Dict, cols: Dict) -> Optional[Column]: def adhoc_metric_to_sqla(
self, metric: Dict[str, Any], cols: Dict[str, Any]
) -> Optional[Column]:
""" """
Turn an adhoc metric into a sqlalchemy column. Turn an adhoc metric into a sqlalchemy column.
@ -804,7 +806,7 @@ class SqlaTable(Model, BaseDatasource):
main_metric_expr = self.make_sqla_column_compatible(main_metric_expr, label) main_metric_expr = self.make_sqla_column_compatible(main_metric_expr, label)
select_exprs: List[Column] = [] select_exprs: List[Column] = []
groupby_exprs_sans_timestamp: OrderedDict = OrderedDict() groupby_exprs_sans_timestamp = OrderedDict()
if (is_sip_38 and metrics and columns) or (not is_sip_38 and groupby): if (is_sip_38 and metrics and columns) or (not is_sip_38 and groupby):
# dedup columns while preserving order # dedup columns while preserving order
@ -874,7 +876,7 @@ class SqlaTable(Model, BaseDatasource):
qry = qry.group_by(*groupby_exprs_with_timestamp.values()) qry = qry.group_by(*groupby_exprs_with_timestamp.values())
where_clause_and = [] where_clause_and = []
having_clause_and: List = [] having_clause_and = []
for flt in filter: # type: ignore for flt in filter: # type: ignore
if not all([flt.get(s) for s in ["col", "op"]]): if not all([flt.get(s) for s in ["col", "op"]]):
@ -1082,7 +1084,10 @@ class SqlaTable(Model, BaseDatasource):
return ob return ob
def _get_top_groups( def _get_top_groups(
self, df: pd.DataFrame, dimensions: List, groupby_exprs: OrderedDict self,
df: pd.DataFrame,
dimensions: List[str],
groupby_exprs: "OrderedDict[str, Any]",
) -> ColumnElement: ) -> ColumnElement:
groups = [] groups = []
for unused, row in df.iterrows(): for unused, row in df.iterrows():

View File

@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
from typing import Dict, List, Optional from typing import Any, Dict, List, Optional
from flask_appbuilder.models.filters import BaseFilter from flask_appbuilder.models.filters import BaseFilter
from flask_appbuilder.models.sqla import Model from flask_appbuilder.models.sqla import Model
@ -75,7 +75,7 @@ class BaseDAO:
return query.all() return query.all()
@classmethod @classmethod
def create(cls, properties: Dict, commit: bool = True) -> Model: def create(cls, properties: Dict[str, Any], commit: bool = True) -> Model:
""" """
Generic for creating models Generic for creating models
:raises: DAOCreateFailedError :raises: DAOCreateFailedError
@ -95,7 +95,9 @@ class BaseDAO:
return model return model
@classmethod @classmethod
def update(cls, model: Model, properties: Dict, commit: bool = True) -> Model: def update(
cls, model: Model, properties: Dict[str, Any], commit: bool = True
) -> Model:
""" """
Generic update a model Generic update a model
:raises: DAOCreateFailedError :raises: DAOCreateFailedError

View File

@ -15,7 +15,7 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import logging import logging
from typing import Dict, List, Optional from typing import Any, Dict, List, Optional
from flask_appbuilder.models.sqla import Model from flask_appbuilder.models.sqla import Model
from flask_appbuilder.security.sqla.models import User from flask_appbuilder.security.sqla.models import User
@ -35,7 +35,7 @@ logger = logging.getLogger(__name__)
class CreateDashboardCommand(BaseCommand): class CreateDashboardCommand(BaseCommand):
def __init__(self, user: User, data: Dict): def __init__(self, user: User, data: Dict[str, Any]):
self._actor = user self._actor = user
self._properties = data.copy() self._properties = data.copy()

View File

@ -15,7 +15,7 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import logging import logging
from typing import Dict, List, Optional from typing import Any, Dict, List, Optional
from flask_appbuilder.models.sqla import Model from flask_appbuilder.models.sqla import Model
from flask_appbuilder.security.sqla.models import User from flask_appbuilder.security.sqla.models import User
@ -40,7 +40,7 @@ logger = logging.getLogger(__name__)
class UpdateDashboardCommand(BaseCommand): class UpdateDashboardCommand(BaseCommand):
def __init__(self, user: User, model_id: int, data: Dict): def __init__(self, user: User, model_id: int, data: Dict[str, Any]):
self._actor = user self._actor = user
self._model_id = model_id self._model_id = model_id
self._properties = data.copy() self._properties = data.copy()

View File

@ -15,7 +15,7 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import logging import logging
from typing import Dict, List, Optional from typing import Any, Dict, List, Optional
from flask_appbuilder.models.sqla import Model from flask_appbuilder.models.sqla import Model
from flask_appbuilder.security.sqla.models import User from flask_appbuilder.security.sqla.models import User
@ -39,7 +39,7 @@ logger = logging.getLogger(__name__)
class CreateDatasetCommand(BaseCommand): class CreateDatasetCommand(BaseCommand):
def __init__(self, user: User, data: Dict): def __init__(self, user: User, data: Dict[str, Any]):
self._actor = user self._actor = user
self._properties = data.copy() self._properties = data.copy()

View File

@ -16,7 +16,7 @@
# under the License. # under the License.
import logging import logging
from collections import Counter from collections import Counter
from typing import Dict, List, Optional from typing import Any, Dict, List, Optional
from flask_appbuilder.models.sqla import Model from flask_appbuilder.models.sqla import Model
from flask_appbuilder.security.sqla.models import User from flask_appbuilder.security.sqla.models import User
@ -48,7 +48,7 @@ logger = logging.getLogger(__name__)
class UpdateDatasetCommand(BaseCommand): class UpdateDatasetCommand(BaseCommand):
def __init__(self, user: User, model_id: int, data: Dict): def __init__(self, user: User, model_id: int, data: Dict[str, Any]):
self._actor = user self._actor = user
self._model_id = model_id self._model_id = model_id
self._properties = data.copy() self._properties = data.copy()
@ -111,7 +111,7 @@ class UpdateDatasetCommand(BaseCommand):
raise exception raise exception
def _validate_columns( def _validate_columns(
self, columns: List[Dict], exceptions: List[ValidationError] self, columns: List[Dict[str, Any]], exceptions: List[ValidationError]
) -> None: ) -> None:
# Validate duplicates on data # Validate duplicates on data
if self._get_duplicates(columns, "column_name"): if self._get_duplicates(columns, "column_name"):
@ -133,7 +133,7 @@ class UpdateDatasetCommand(BaseCommand):
exceptions.append(DatasetColumnsExistsValidationError()) exceptions.append(DatasetColumnsExistsValidationError())
def _validate_metrics( def _validate_metrics(
self, metrics: List[Dict], exceptions: List[ValidationError] self, metrics: List[Dict[str, Any]], exceptions: List[ValidationError]
) -> None: ) -> None:
if self._get_duplicates(metrics, "metric_name"): if self._get_duplicates(metrics, "metric_name"):
exceptions.append(DatasetMetricsDuplicateValidationError()) exceptions.append(DatasetMetricsDuplicateValidationError())
@ -152,7 +152,7 @@ class UpdateDatasetCommand(BaseCommand):
exceptions.append(DatasetMetricsExistsValidationError()) exceptions.append(DatasetMetricsExistsValidationError())
@staticmethod @staticmethod
def _get_duplicates(data: List[Dict], key: str) -> List[str]: def _get_duplicates(data: List[Dict[str, Any]], key: str) -> List[str]:
duplicates = [ duplicates = [
name name
for name, count in Counter([item[key] for item in data]).items() for name, count in Counter([item[key] for item in data]).items()

View File

@ -15,7 +15,7 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import logging import logging
from typing import Dict, List, Optional from typing import Any, Dict, List, Optional
from flask import current_app from flask import current_app
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
@ -116,7 +116,7 @@ class DatasetDAO(BaseDAO):
@classmethod @classmethod
def update( def update(
cls, model: SqlaTable, properties: Dict, commit: bool = True cls, model: SqlaTable, properties: Dict[str, Any], commit: bool = True
) -> Optional[SqlaTable]: ) -> Optional[SqlaTable]:
""" """
Updates a Dataset model on the metadata DB Updates a Dataset model on the metadata DB
@ -151,13 +151,13 @@ class DatasetDAO(BaseDAO):
@classmethod @classmethod
def update_column( def update_column(
cls, model: TableColumn, properties: Dict, commit: bool = True cls, model: TableColumn, properties: Dict[str, Any], commit: bool = True
) -> Optional[TableColumn]: ) -> Optional[TableColumn]:
return DatasetColumnDAO.update(model, properties, commit=commit) return DatasetColumnDAO.update(model, properties, commit=commit)
@classmethod @classmethod
def create_column( def create_column(
cls, properties: Dict, commit: bool = True cls, properties: Dict[str, Any], commit: bool = True
) -> Optional[TableColumn]: ) -> Optional[TableColumn]:
""" """
Creates a Dataset model on the metadata DB Creates a Dataset model on the metadata DB
@ -166,13 +166,13 @@ class DatasetDAO(BaseDAO):
@classmethod @classmethod
def update_metric( def update_metric(
cls, model: SqlMetric, properties: Dict, commit: bool = True cls, model: SqlMetric, properties: Dict[str, Any], commit: bool = True
) -> Optional[SqlMetric]: ) -> Optional[SqlMetric]:
return DatasetMetricDAO.update(model, properties, commit=commit) return DatasetMetricDAO.update(model, properties, commit=commit)
@classmethod @classmethod
def create_metric( def create_metric(
cls, properties: Dict, commit: bool = True cls, properties: Dict[str, Any], commit: bool = True
) -> Optional[SqlMetric]: ) -> Optional[SqlMetric]:
""" """
Creates a Dataset model on the metadata DB Creates a Dataset model on the metadata DB

View File

@ -151,7 +151,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
try_remove_schema_from_table_name = True # pylint: disable=invalid-name try_remove_schema_from_table_name = True # pylint: disable=invalid-name
# default matching patterns for identifying column types # default matching patterns for identifying column types
db_column_types: Dict[utils.DbColumnType, Tuple[Pattern, ...]] = { db_column_types: Dict[utils.DbColumnType, Tuple[Pattern[Any], ...]] = {
utils.DbColumnType.NUMERIC: ( utils.DbColumnType.NUMERIC: (
re.compile(r".*DOUBLE.*", re.IGNORECASE), re.compile(r".*DOUBLE.*", re.IGNORECASE),
re.compile(r".*FLOAT.*", re.IGNORECASE), re.compile(r".*FLOAT.*", re.IGNORECASE),
@ -296,7 +296,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return select_exprs return select_exprs
@classmethod @classmethod
def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple]: def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple[Any, ...]]:
""" """
:param cursor: Cursor instance :param cursor: Cursor instance
@ -311,8 +311,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@classmethod @classmethod
def expand_data( def expand_data(
cls, columns: List[dict], data: List[dict] cls, columns: List[Dict[Any, Any]], data: List[Dict[Any, Any]]
) -> Tuple[List[dict], List[dict], List[dict]]: ) -> Tuple[List[Dict[Any, Any]], List[Dict[Any, Any]], List[Dict[Any, Any]]]:
""" """
Some engines support expanding nested fields. See implementation in Presto Some engines support expanding nested fields. See implementation in Presto
spec for details. spec for details.
@ -645,7 +645,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
schema: Optional[str], schema: Optional[str],
database: "Database", database: "Database",
query: Select, query: Select,
columns: Optional[List] = None, columns: Optional[List[Dict[str, str]]] = None,
) -> Optional[Select]: ) -> Optional[Select]:
""" """
Add a where clause to a query to reference only the most recent partition Add a where clause to a query to reference only the most recent partition
@ -925,7 +925,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return [] return []
@staticmethod @staticmethod
def pyodbc_rows_to_tuples(data: List[Any]) -> List[Tuple]: def pyodbc_rows_to_tuples(data: List[Any]) -> List[Tuple[Any, ...]]:
""" """
Convert pyodbc.Row objects from `fetch_data` to tuples. Convert pyodbc.Row objects from `fetch_data` to tuples.

View File

@ -83,7 +83,7 @@ class BigQueryEngineSpec(BaseEngineSpec):
return None return None
@classmethod @classmethod
def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple]: def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple[Any, ...]]:
data = super().fetch_data(cursor, limit) data = super().fetch_data(cursor, limit)
# Support type BigQuery Row, introduced here PR #4071 # Support type BigQuery Row, introduced here PR #4071
# google.cloud.bigquery.table.Row # google.cloud.bigquery.table.Row

View File

@ -39,7 +39,7 @@ class ExasolEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
} }
@classmethod @classmethod
def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple]: def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple[Any, ...]]:
data = super().fetch_data(cursor, limit) data = super().fetch_data(cursor, limit)
# Lists of `pyodbc.Row` need to be unpacked further # Lists of `pyodbc.Row` need to be unpacked further
return cls.pyodbc_rows_to_tuples(data) return cls.pyodbc_rows_to_tuples(data)

View File

@ -93,7 +93,7 @@ class HiveEngineSpec(PrestoEngineSpec):
return BaseEngineSpec.get_all_datasource_names(database, datasource_type) return BaseEngineSpec.get_all_datasource_names(database, datasource_type)
@classmethod @classmethod
def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple]: def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple[Any, ...]]:
import pyhive import pyhive
from TCLIService import ttypes from TCLIService import ttypes
@ -304,7 +304,7 @@ class HiveEngineSpec(PrestoEngineSpec):
schema: Optional[str], schema: Optional[str],
database: "Database", database: "Database",
query: Select, query: Select,
columns: Optional[List] = None, columns: Optional[List[Dict[str, str]]] = None,
) -> Optional[Select]: ) -> Optional[Select]:
try: try:
col_names, values = cls.latest_partition( col_names, values = cls.latest_partition(
@ -323,7 +323,7 @@ class HiveEngineSpec(PrestoEngineSpec):
return None return None
@classmethod @classmethod
def _get_fields(cls, cols: List[dict]) -> List[ColumnClause]: def _get_fields(cls, cols: List[Dict[str, Any]]) -> List[ColumnClause]:
return BaseEngineSpec._get_fields(cols) # pylint: disable=protected-access return BaseEngineSpec._get_fields(cols) # pylint: disable=protected-access
@classmethod @classmethod

View File

@ -66,7 +66,7 @@ class MssqlEngineSpec(BaseEngineSpec):
return None return None
@classmethod @classmethod
def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple]: def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple[Any, ...]]:
data = super().fetch_data(cursor, limit) data = super().fetch_data(cursor, limit)
# Lists of `pyodbc.Row` need to be unpacked further # Lists of `pyodbc.Row` need to be unpacked further
return cls.pyodbc_rows_to_tuples(data) return cls.pyodbc_rows_to_tuples(data)

View File

@ -51,7 +51,7 @@ class PostgresBaseEngineSpec(BaseEngineSpec):
} }
@classmethod @classmethod
def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple]: def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple[Any, ...]]:
cursor.tzinfo_factory = FixedOffsetTimezone cursor.tzinfo_factory = FixedOffsetTimezone
if not cursor.description: if not cursor.description:
return [] return []

View File

@ -164,7 +164,7 @@ class PrestoEngineSpec(BaseEngineSpec):
return [row[0] for row in results] return [row[0] for row in results]
@classmethod @classmethod
def _create_column_info(cls, name: str, data_type: str) -> dict: def _create_column_info(cls, name: str, data_type: str) -> Dict[str, Any]:
""" """
Create column info object Create column info object
:param name: column name :param name: column name
@ -213,7 +213,10 @@ class PrestoEngineSpec(BaseEngineSpec):
@classmethod @classmethod
def _parse_structural_column( # pylint: disable=too-many-locals,too-many-branches def _parse_structural_column( # pylint: disable=too-many-locals,too-many-branches
cls, parent_column_name: str, parent_data_type: str, result: List[dict] cls,
parent_column_name: str,
parent_data_type: str,
result: List[Dict[str, Any]],
) -> None: ) -> None:
""" """
Parse a row or array column Parse a row or array column
@ -322,7 +325,7 @@ class PrestoEngineSpec(BaseEngineSpec):
(i.e. column name and data type) (i.e. column name and data type)
""" """
columns = cls._show_columns(inspector, table_name, schema) columns = cls._show_columns(inspector, table_name, schema)
result: List[dict] = [] result: List[Dict[str, Any]] = []
for column in columns: for column in columns:
try: try:
# parse column if it is a row or array # parse column if it is a row or array
@ -361,7 +364,7 @@ class PrestoEngineSpec(BaseEngineSpec):
return column_name.startswith('"') and column_name.endswith('"') return column_name.startswith('"') and column_name.endswith('"')
@classmethod @classmethod
def _get_fields(cls, cols: List[dict]) -> List[ColumnClause]: def _get_fields(cls, cols: List[Dict[str, Any]]) -> List[ColumnClause]:
""" """
Format column clauses where names are in quotes and labels are specified Format column clauses where names are in quotes and labels are specified
:param cols: columns :param cols: columns
@ -561,8 +564,8 @@ class PrestoEngineSpec(BaseEngineSpec):
@classmethod @classmethod
def expand_data( # pylint: disable=too-many-locals def expand_data( # pylint: disable=too-many-locals
cls, columns: List[dict], data: List[dict] cls, columns: List[Dict[Any, Any]], data: List[Dict[Any, Any]]
) -> Tuple[List[dict], List[dict], List[dict]]: ) -> Tuple[List[Dict[Any, Any]], List[Dict[Any, Any]], List[Dict[Any, Any]]]:
""" """
We do not immediately display rows and arrays clearly in the data grid. This We do not immediately display rows and arrays clearly in the data grid. This
method separates out nested fields and data values to help clearly display method separates out nested fields and data values to help clearly display
@ -590,7 +593,7 @@ class PrestoEngineSpec(BaseEngineSpec):
# process each column, unnesting ARRAY types and # process each column, unnesting ARRAY types and
# expanding ROW types into new columns # expanding ROW types into new columns
to_process = deque((column, 0) for column in columns) to_process = deque((column, 0) for column in columns)
all_columns: List[dict] = [] all_columns: List[Dict[str, Any]] = []
expanded_columns = [] expanded_columns = []
current_array_level = None current_array_level = None
while to_process: while to_process:
@ -843,7 +846,7 @@ class PrestoEngineSpec(BaseEngineSpec):
schema: Optional[str], schema: Optional[str],
database: "Database", database: "Database",
query: Select, query: Select,
columns: Optional[List] = None, columns: Optional[List[Dict[str, str]]] = None,
) -> Optional[Select]: ) -> Optional[Select]:
try: try:
col_names, values = cls.latest_partition( col_names, values = cls.latest_partition(

View File

@ -95,7 +95,9 @@ class UIManifestProcessor:
self.parse_manifest_json() self.parse_manifest_json()
@app.context_processor @app.context_processor
def get_manifest() -> Dict[str, Callable]: # pylint: disable=unused-variable def get_manifest() -> Dict[ # pylint: disable=unused-variable
str, Callable[[str], List[str]]
]:
loaded_chunks = set() loaded_chunks = set()
def get_files(bundle: str, asset_type: str = "js") -> List[str]: def get_files(bundle: str, asset_type: str = "js") -> List[str]:
@ -131,7 +133,7 @@ appbuilder = AppBuilder(update_perms=False)
cache_manager = CacheManager() cache_manager = CacheManager()
celery_app = celery.Celery() celery_app = celery.Celery()
db = SQLA() db = SQLA()
_event_logger: dict = {} _event_logger: Dict[str, Any] = {}
event_logger = LocalProxy(lambda: _event_logger.get("event_logger")) event_logger = LocalProxy(lambda: _event_logger.get("event_logger"))
feature_flag_manager = FeatureFlagManager() feature_flag_manager = FeatureFlagManager()
jinja_context_manager = JinjaContextManager() jinja_context_manager = JinjaContextManager()

View File

@ -341,11 +341,14 @@ class Database(
def get_reserved_words(self) -> Set[str]: def get_reserved_words(self) -> Set[str]:
return self.get_dialect().preparer.reserved_words return self.get_dialect().preparer.reserved_words
def get_quoter(self) -> Callable: def get_quoter(self) -> Callable[[str, Any], str]:
return self.get_dialect().identifier_preparer.quote return self.get_dialect().identifier_preparer.quote
def get_df( # pylint: disable=too-many-locals def get_df( # pylint: disable=too-many-locals
self, sql: str, schema: Optional[str] = None, mutator: Optional[Callable] = None self,
sql: str,
schema: Optional[str] = None,
mutator: Optional[Callable[[pd.DataFrame], None]] = None,
) -> pd.DataFrame: ) -> pd.DataFrame:
sqls = [str(s).strip(" ;") for s in sqlparse.parse(sql)] sqls = [str(s).strip(" ;") for s in sqlparse.parse(sql)]
@ -450,7 +453,7 @@ class Database(
@cache_util.memoized_func( @cache_util.memoized_func(
key=lambda *args, **kwargs: "db:{}:schema:None:view_list", key=lambda *args, **kwargs: "db:{}:schema:None:view_list",
attribute_in_key="id", # type: ignore attribute_in_key="id",
) )
def get_all_view_names_in_database( def get_all_view_names_in_database(
self, self,

View File

@ -240,7 +240,7 @@ class Dashboard( # pylint: disable=too-many-instance-attributes
self.json_metadata = value self.json_metadata = value
@property @property
def position(self) -> Dict: def position(self) -> Dict[str, Any]:
if self.position_json: if self.position_json:
return json.loads(self.position_json) return json.loads(self.position_json)
return {} return {}
@ -315,7 +315,7 @@ class Dashboard( # pylint: disable=too-many-instance-attributes
old_to_new_slc_id_dict: Dict[int, int] = {} old_to_new_slc_id_dict: Dict[int, int] = {}
new_timed_refresh_immune_slices = [] new_timed_refresh_immune_slices = []
new_expanded_slices = {} new_expanded_slices = {}
new_filter_scopes: Dict[str, Dict] = {} new_filter_scopes = {}
i_params_dict = dashboard_to_import.params_dict i_params_dict = dashboard_to_import.params_dict
remote_id_slice_map = { remote_id_slice_map = {
slc.params_dict["remote_id"]: slc slc.params_dict["remote_id"]: slc
@ -351,7 +351,7 @@ class Dashboard( # pylint: disable=too-many-instance-attributes
# are converted to filter_scopes # are converted to filter_scopes
# but dashboard create from import may still have old dashboard filter metadata # but dashboard create from import may still have old dashboard filter metadata
# here we convert them to new filter_scopes metadata first # here we convert them to new filter_scopes metadata first
filter_scopes: Dict = {} filter_scopes = {}
if ( if (
"filter_immune_slices" in i_params_dict "filter_immune_slices" in i_params_dict
or "filter_immune_slice_fields" in i_params_dict or "filter_immune_slice_fields" in i_params_dict
@ -415,7 +415,7 @@ class Dashboard( # pylint: disable=too-many-instance-attributes
@classmethod @classmethod
def export_dashboards( # pylint: disable=too-many-locals def export_dashboards( # pylint: disable=too-many-locals
cls, dashboard_ids: List cls, dashboard_ids: List[int]
) -> str: ) -> str:
copied_dashboards = [] copied_dashboards = []
datasource_ids = set() datasource_ids = set()

View File

@ -81,7 +81,7 @@ class ImportMixin:
for u in cls.__table_args__ # type: ignore for u in cls.__table_args__ # type: ignore
if isinstance(u, UniqueConstraint) if isinstance(u, UniqueConstraint)
] ]
unique.extend( # type: ignore unique.extend(
{c.name} for c in cls.__table__.columns if c.unique # type: ignore {c.name} for c in cls.__table__.columns if c.unique # type: ignore
) )
return unique return unique

View File

@ -36,7 +36,7 @@ from superset.tasks.thumbnails import cache_chart_thumbnail
from superset.utils import core as utils from superset.utils import core as utils
if is_feature_enabled("SIP_38_VIZ_REARCHITECTURE"): if is_feature_enabled("SIP_38_VIZ_REARCHITECTURE"):
from superset.viz_sip38 import BaseViz, viz_types # type: ignore from superset.viz_sip38 import BaseViz, viz_types
else: else:
from superset.viz import BaseViz, viz_types # type: ignore from superset.viz import BaseViz, viz_types # type: ignore

View File

@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
from typing import Any, Optional, Type from typing import Any, Dict, List, Optional, Type
from sqlalchemy import types from sqlalchemy import types
from sqlalchemy.sql.sqltypes import Integer from sqlalchemy.sql.sqltypes import Integer
@ -29,7 +29,7 @@ class TinyInteger(Integer):
A type for tiny ``int`` integers. A type for tiny ``int`` integers.
""" """
def python_type(self) -> Type: def python_type(self) -> Type[int]:
return int return int
@classmethod @classmethod
@ -42,7 +42,7 @@ class Interval(TypeEngine):
A type for intervals. A type for intervals.
""" """
def python_type(self) -> Optional[Type]: def python_type(self) -> Optional[Type[Any]]:
return None return None
@classmethod @classmethod
@ -55,7 +55,7 @@ class Array(TypeEngine):
A type for arrays. A type for arrays.
""" """
def python_type(self) -> Optional[Type]: def python_type(self) -> Optional[Type[List[Any]]]:
return list return list
@classmethod @classmethod
@ -68,7 +68,7 @@ class Map(TypeEngine):
A type for maps. A type for maps.
""" """
def python_type(self) -> Optional[Type]: def python_type(self) -> Optional[Type[Dict[Any, Any]]]:
return dict return dict
@classmethod @classmethod
@ -81,7 +81,7 @@ class Row(TypeEngine):
A type for rows. A type for rows.
""" """
def python_type(self) -> Optional[Type]: def python_type(self) -> Optional[Type[Any]]:
return None return None
@classmethod @classmethod

View File

@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
from typing import Callable from typing import Any
from flask import g from flask import g
from flask_sqlalchemy import BaseQuery from flask_sqlalchemy import BaseQuery
@ -25,7 +25,7 @@ from superset.views.base import BaseFilter
class QueryFilter(BaseFilter): # pylint: disable=too-few-public-methods class QueryFilter(BaseFilter): # pylint: disable=too-few-public-methods
def apply(self, query: BaseQuery, value: Callable) -> BaseQuery: def apply(self, query: BaseQuery, value: Any) -> BaseQuery:
""" """
Filter queries to only those owned by current user. If Filter queries to only those owned by current user. If
can_access_all_queries permission is set a user can list all queries can_access_all_queries permission is set a user can list all queries

View File

@ -20,7 +20,7 @@
import datetime import datetime
import json import json
import logging import logging
from typing import Any, Callable, Dict, List, Optional, Tuple, Type from typing import Any, Dict, List, Optional, Tuple, Type
import numpy as np import numpy as np
import pandas as pd import pandas as pd
@ -64,7 +64,7 @@ def stringify(obj: Any) -> str:
def stringify_values(array: np.ndarray) -> np.ndarray: def stringify_values(array: np.ndarray) -> np.ndarray:
vstringify: Callable = np.vectorize(stringify) vstringify = np.vectorize(stringify)
return vstringify(array) return vstringify(array)
@ -172,7 +172,7 @@ class SupersetResultSet:
return table.to_pandas(integer_object_nulls=True) return table.to_pandas(integer_object_nulls=True)
@staticmethod @staticmethod
def first_nonempty(items: List) -> Any: def first_nonempty(items: List[Any]) -> Any:
return next((i for i in items if i), None) return next((i for i in items if i), None)
def is_temporal(self, db_type_str: Optional[str]) -> bool: def is_temporal(self, db_type_str: Optional[str]) -> bool:

View File

@ -21,11 +21,11 @@ from typing import Any, Callable, List, Optional, Set, Tuple, TYPE_CHECKING, Uni
from flask import current_app, g from flask import current_app, g
from flask_appbuilder import Model from flask_appbuilder import Model
from flask_appbuilder.security.sqla import models as ab_models
from flask_appbuilder.security.sqla.manager import SecurityManager from flask_appbuilder.security.sqla.manager import SecurityManager
from flask_appbuilder.security.sqla.models import ( from flask_appbuilder.security.sqla.models import (
assoc_permissionview_role, assoc_permissionview_role,
assoc_user_role, assoc_user_role,
PermissionView,
) )
from flask_appbuilder.security.views import ( from flask_appbuilder.security.views import (
PermissionModelView, PermissionModelView,
@ -602,11 +602,8 @@ class SupersetSecurityManager(SecurityManager):
logger.info("Cleaning faulty perms") logger.info("Cleaning faulty perms")
sesh = self.get_session sesh = self.get_session
pvms = sesh.query(ab_models.PermissionView).filter( pvms = sesh.query(PermissionView).filter(
or_( or_(PermissionView.permission == None, PermissionView.view_menu == None,)
ab_models.PermissionView.permission == None,
ab_models.PermissionView.view_menu == None,
)
) )
deleted_count = pvms.delete() deleted_count = pvms.delete()
sesh.commit() sesh.commit()
@ -640,7 +637,9 @@ class SupersetSecurityManager(SecurityManager):
self.get_session.commit() self.get_session.commit()
self.clean_perms() self.clean_perms()
def set_role(self, role_name: str, pvm_check: Callable) -> None: def set_role(
self, role_name: str, pvm_check: Callable[[PermissionView], bool]
) -> None:
""" """
Set the FAB permission/views for the role. Set the FAB permission/views for the role.
@ -650,7 +649,7 @@ class SupersetSecurityManager(SecurityManager):
logger.info("Syncing {} perms".format(role_name)) logger.info("Syncing {} perms".format(role_name))
sesh = self.get_session sesh = self.get_session
pvms = sesh.query(ab_models.PermissionView).all() pvms = sesh.query(PermissionView).all()
pvms = [p for p in pvms if p.permission and p.view_menu] pvms = [p for p in pvms if p.permission and p.view_menu]
role = self.add_role(role_name) role = self.add_role(role_name)
role_pvms = [p for p in pvms if pvm_check(p)] role_pvms = [p for p in pvms if pvm_check(p)]

View File

@ -299,9 +299,10 @@ def _serialize_and_expand_data(
db_engine_spec: BaseEngineSpec, db_engine_spec: BaseEngineSpec,
use_msgpack: Optional[bool] = False, use_msgpack: Optional[bool] = False,
expand_data: bool = False, expand_data: bool = False,
) -> Tuple[Union[bytes, str], list, list, list]: ) -> Tuple[Union[bytes, str], List[Any], List[Any], List[Any]]:
selected_columns: List[Dict] = result_set.columns selected_columns = result_set.columns
expanded_columns: List[Dict] all_columns: List[Any]
expanded_columns: List[Any]
if use_msgpack: if use_msgpack:
with stats_timing( with stats_timing(

View File

@ -25,7 +25,7 @@ from superset import create_app
from superset.extensions import celery_app from superset.extensions import celery_app
# Init the Flask app / configure everything # Init the Flask app / configure everything
create_app() # type: ignore create_app()
# Need to import late, as the celery_app will have been setup by "create_app()" # Need to import late, as the celery_app will have been setup by "create_app()"
# pylint: disable=wrong-import-position, unused-import # pylint: disable=wrong-import-position, unused-import

View File

@ -23,7 +23,7 @@ import urllib.request
from collections import namedtuple from collections import namedtuple
from datetime import datetime, timedelta from datetime import datetime, timedelta
from email.utils import make_msgid, parseaddr from email.utils import make_msgid, parseaddr
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union from typing import Any, Dict, Iterator, List, Optional, Tuple, TYPE_CHECKING, Union
from urllib.error import URLError # pylint: disable=ungrouped-imports from urllib.error import URLError # pylint: disable=ungrouped-imports
import croniter import croniter
@ -36,7 +36,6 @@ from flask_login import login_user
from retry.api import retry_call from retry.api import retry_call
from selenium.common.exceptions import WebDriverException from selenium.common.exceptions import WebDriverException
from selenium.webdriver import chrome, firefox from selenium.webdriver import chrome, firefox
from werkzeug.datastructures import TypeConversionDict
from werkzeug.http import parse_cookie from werkzeug.http import parse_cookie
# Superset framework imports # Superset framework imports
@ -53,6 +52,11 @@ from superset.models.schedules import (
) )
from superset.utils.core import get_email_address_list, send_email_smtp from superset.utils.core import get_email_address_list, send_email_smtp
if TYPE_CHECKING:
# pylint: disable=unused-import
from werkzeug.datastructures import TypeConversionDict
# Globals # Globals
config = app.config config = app.config
logger = logging.getLogger("tasks.email_reports") logger = logging.getLogger("tasks.email_reports")
@ -131,7 +135,7 @@ def _generate_mail_content(
return EmailContent(body, data, images) return EmailContent(body, data, images)
def _get_auth_cookies() -> List[TypeConversionDict]: def _get_auth_cookies() -> List["TypeConversionDict[Any, Any]"]:
# Login with the user specified to get the reports # Login with the user specified to get the reports
with app.test_request_context(): with app.test_request_context():
user = security_manager.find_user(config["EMAIL_REPORTS_USER"]) user = security_manager.find_user(config["EMAIL_REPORTS_USER"])

View File

@ -27,8 +27,9 @@ def view_cache_key(*args: Any, **kwargs: Any) -> str: # pylint: disable=unused-
def memoized_func( def memoized_func(
key: Callable = view_cache_key, attribute_in_key: Optional[str] = None key: Callable[..., str] = view_cache_key, # pylint: disable=bad-whitespace
) -> Callable: attribute_in_key: Optional[str] = None,
) -> Callable[..., Any]:
"""Use this decorator to cache functions that have predefined first arg. """Use this decorator to cache functions that have predefined first arg.
enable_cache is treated as True by default, enable_cache is treated as True by default,
@ -45,7 +46,7 @@ def memoized_func(
returns the caching key. returns the caching key.
""" """
def wrap(f: Callable) -> Callable: def wrap(f: Callable[..., Any]) -> Callable[..., Any]:
if cache_manager.tables_cache: if cache_manager.tables_cache:
def wrapped_f(self: Any, *args: Any, **kwargs: Any) -> Any: def wrapped_f(self: Any, *args: Any, **kwargs: Any) -> Any:

View File

@ -85,7 +85,7 @@ from superset.exceptions import (
SupersetException, SupersetException,
SupersetTimeoutException, SupersetTimeoutException,
) )
from superset.typing import FormData, Metric from superset.typing import FlaskResponse, FormData, Metric
from superset.utils.dates import datetime_to_epoch, EPOCH from superset.utils.dates import datetime_to_epoch, EPOCH
try: try:
@ -147,7 +147,9 @@ class _memoized:
should account for instance variable changes. should account for instance variable changes.
""" """
def __init__(self, func: Callable, watch: Optional[Tuple[str, ...]] = None) -> None: def __init__(
self, func: Callable[..., Any], watch: Optional[Tuple[str, ...]] = None
) -> None:
self.func = func self.func = func
self.cache: Dict[Any, Any] = {} self.cache: Dict[Any, Any] = {}
self.is_method = False self.is_method = False
@ -173,7 +175,7 @@ class _memoized:
"""Return the function's docstring.""" """Return the function's docstring."""
return self.func.__doc__ or "" return self.func.__doc__ or ""
def __get__(self, obj: Any, objtype: Type) -> functools.partial: def __get__(self, obj: Any, objtype: Type[Any]) -> functools.partial: # type: ignore
if not self.is_method: if not self.is_method:
self.is_method = True self.is_method = True
"""Support instance methods.""" """Support instance methods."""
@ -181,13 +183,13 @@ class _memoized:
def memoized( def memoized(
func: Optional[Callable] = None, watch: Optional[Tuple[str, ...]] = None func: Optional[Callable[..., Any]] = None, watch: Optional[Tuple[str, ...]] = None
) -> Callable: ) -> Callable[..., Any]:
if func: if func:
return _memoized(func) return _memoized(func)
else: else:
def wrapper(f: Callable) -> Callable: def wrapper(f: Callable[..., Any]) -> Callable[..., Any]:
return _memoized(f, watch) return _memoized(f, watch)
return wrapper return wrapper
@ -1241,7 +1243,9 @@ def create_ssl_cert_file(certificate: str) -> str:
return path return path
def time_function(func: Callable, *args: Any, **kwargs: Any) -> Tuple[float, Any]: def time_function(
func: Callable[..., FlaskResponse], *args: Any, **kwargs: Any
) -> Tuple[float, Any]:
""" """
Measures the amount of time a function takes to execute in ms Measures the amount of time a function takes to execute in ms

View File

@ -29,7 +29,7 @@ def convert_filter_scopes(
) -> Dict[int, Dict[str, Dict[str, Any]]]: ) -> Dict[int, Dict[str, Dict[str, Any]]]:
filter_scopes = {} filter_scopes = {}
immuned_by_id: List[int] = json_metadata.get("filter_immune_slices") or [] immuned_by_id: List[int] = json_metadata.get("filter_immune_slices") or []
immuned_by_column: Dict = defaultdict(list) immuned_by_column: Dict[str, List[int]] = defaultdict(list)
for slice_id, columns in json_metadata.get( for slice_id, columns in json_metadata.get(
"filter_immune_slice_fields", {} "filter_immune_slice_fields", {}
).items(): ).items():
@ -52,7 +52,7 @@ def convert_filter_scopes(
logging.info(f"slice [{filter_id}] has invalid field: {filter_field}") logging.info(f"slice [{filter_id}] has invalid field: {filter_field}")
for filter_slice in filters: for filter_slice in filters:
filter_fields: Dict = {} filter_fields: Dict[str, Dict[str, Any]] = {}
filter_id = filter_slice.id filter_id = filter_slice.id
slice_params = json.loads(filter_slice.params or "{}") slice_params = json.loads(filter_slice.params or "{}")
configs = slice_params.get("filter_configs") or [] configs = slice_params.get("filter_configs") or []
@ -77,9 +77,10 @@ def convert_filter_scopes(
def copy_filter_scopes( def copy_filter_scopes(
old_to_new_slc_id_dict: Dict[int, int], old_filter_scopes: Dict[str, Dict] old_to_new_slc_id_dict: Dict[int, int],
) -> Dict: old_filter_scopes: Dict[int, Dict[str, Dict[str, Any]]],
new_filter_scopes: Dict[str, Dict] = {} ) -> Dict[str, Dict[Any, Any]]:
new_filter_scopes: Dict[str, Dict[Any, Any]] = {}
for (filter_id, scopes) in old_filter_scopes.items(): for (filter_id, scopes) in old_filter_scopes.items():
new_filter_key = old_to_new_slc_id_dict.get(int(filter_id)) new_filter_key = old_to_new_slc_id_dict.get(int(filter_id))
if new_filter_key: if new_filter_key:

View File

@ -46,7 +46,7 @@ def stats_timing(stats_key: str, stats_logger: BaseStatsLogger) -> Iterator[floa
stats_logger.timing(stats_key, now_as_float() - start_ts) stats_logger.timing(stats_key, now_as_float() - start_ts)
def etag_cache(max_age: int, check_perms: Callable) -> Callable: def etag_cache(max_age: int, check_perms: Callable[..., Any]) -> Callable[..., Any]:
""" """
A decorator for caching views and handling etag conditional requests. A decorator for caching views and handling etag conditional requests.
@ -60,7 +60,7 @@ def etag_cache(max_age: int, check_perms: Callable) -> Callable:
""" """
def decorator(f: Callable) -> Callable: def decorator(f: Callable[..., Any]) -> Callable[..., Any]:
@wraps(f) @wraps(f)
def wrapper(*args: Any, **kwargs: Any) -> ETagResponseMixin: def wrapper(*args: Any, **kwargs: Any) -> ETagResponseMixin:
# check if the user can access the resource # check if the user can access the resource

View File

@ -27,8 +27,8 @@ logger = logging.getLogger(__name__)
def import_datasource( def import_datasource(
session: Session, session: Session,
i_datasource: Model, i_datasource: Model,
lookup_database: Callable, lookup_database: Callable[[Model], Model],
lookup_datasource: Callable, lookup_datasource: Callable[[Model], Model],
import_time: Optional[int] = None, import_time: Optional[int] = None,
) -> int: ) -> int:
"""Imports the datasource from the object to the database. """Imports the datasource from the object to the database.
@ -82,7 +82,9 @@ def import_datasource(
return datasource.id return datasource.id
def import_simple_obj(session: Session, i_obj: Model, lookup_obj: Callable) -> Model: def import_simple_obj(
session: Session, i_obj: Model, lookup_obj: Callable[[Model], Model]
) -> Model:
make_transient(i_obj) make_transient(i_obj)
i_obj.id = None i_obj.id = None
i_obj.table = None i_obj.table = None

View File

@ -35,7 +35,7 @@ class AbstractEventLogger(ABC):
) -> None: ) -> None:
pass pass
def log_this(self, f: Callable) -> Callable: def log_this(self, f: Callable[..., Any]) -> Callable[..., Any]:
@functools.wraps(f) @functools.wraps(f)
def wrapper(*args: Any, **kwargs: Any) -> Any: def wrapper(*args: Any, **kwargs: Any) -> Any:
user_id = None user_id = None
@ -124,7 +124,7 @@ def get_event_logger_from_cfg_value(cfg_value: Any) -> AbstractEventLogger:
) )
) )
event_logger_type = cast(Type, cfg_value) event_logger_type = cast(Type[Any], cfg_value)
result = event_logger_type() result = event_logger_type()
# Verify that we have a valid logger impl # Verify that we have a valid logger impl

View File

@ -58,7 +58,7 @@ class DefaultLoggingConfigurator(LoggingConfigurator):
if app_config["ENABLE_TIME_ROTATE"]: if app_config["ENABLE_TIME_ROTATE"]:
logging.getLogger().setLevel(app_config["TIME_ROTATE_LOG_LEVEL"]) logging.getLogger().setLevel(app_config["TIME_ROTATE_LOG_LEVEL"])
handler = TimedRotatingFileHandler( # type: ignore handler = TimedRotatingFileHandler(
app_config["FILENAME"], app_config["FILENAME"],
when=app_config["ROLLOVER"], when=app_config["ROLLOVER"],
interval=app_config["INTERVAL"], interval=app_config["INTERVAL"],

View File

@ -72,8 +72,8 @@ WHITELIST_CUMULATIVE_FUNCTIONS = (
) )
def validate_column_args(*argnames: str) -> Callable: def validate_column_args(*argnames: str) -> Callable[..., Any]:
def wrapper(func: Callable) -> Callable: def wrapper(func: Callable[..., Any]) -> Callable[..., Any]:
def wrapped(df: DataFrame, **options: Any) -> Any: def wrapped(df: DataFrame, **options: Any) -> Any:
columns = df.columns.tolist() columns = df.columns.tolist()
for name in argnames: for name in argnames:
@ -471,7 +471,7 @@ def geodetic_parse(
Parse a string containing a geodetic point and return latitude, longitude Parse a string containing a geodetic point and return latitude, longitude
and altitude and altitude
""" """
point = Point(location) # type: ignore point = Point(location)
return point[0], point[1], point[2] return point[0], point[1], point[2]
try: try:

View File

@ -51,7 +51,7 @@ SELENIUM_HEADSTART = 3
WindowSize = Tuple[int, int] WindowSize = Tuple[int, int]
def get_auth_cookies(user: "User") -> List[Dict]: def get_auth_cookies(user: "User") -> List[Dict[Any, Any]]:
# Login with the user specified to get the reports # Login with the user specified to get the reports
with current_app.test_request_context("/login"): with current_app.test_request_context("/login"):
login_user(user) login_user(user)
@ -101,14 +101,14 @@ class AuthWebDriverProxy:
self, self,
driver_type: str, driver_type: str,
window: Optional[WindowSize] = None, window: Optional[WindowSize] = None,
auth_func: Optional[Callable] = None, auth_func: Optional[
Callable[..., Any]
] = None, # pylint: disable=bad-whitespace
): ):
self._driver_type = driver_type self._driver_type = driver_type
self._window: WindowSize = window or (800, 600) self._window: WindowSize = window or (800, 600)
config_auth_func: Callable = current_app.config.get( config_auth_func = current_app.config.get("WEBDRIVER_AUTH_FUNC", auth_driver)
"WEBDRIVER_AUTH_FUNC", auth_driver self._auth_func = auth_func or config_auth_func
)
self._auth_func: Callable = auth_func or config_auth_func
def create(self) -> WebDriver: def create(self) -> WebDriver:
if self._driver_type == "firefox": if self._driver_type == "firefox":
@ -123,7 +123,7 @@ class AuthWebDriverProxy:
raise Exception(f"Webdriver name ({self._driver_type}) not supported") raise Exception(f"Webdriver name ({self._driver_type}) not supported")
# Prepare args for the webdriver init # Prepare args for the webdriver init
options.add_argument("--headless") options.add_argument("--headless")
kwargs: Dict = dict(options=options) kwargs: Dict[Any, Any] = dict(options=options)
kwargs.update(current_app.config["WEBDRIVER_CONFIGURATION"]) kwargs.update(current_app.config["WEBDRIVER_CONFIGURATION"])
logger.info("Init selenium driver") logger.info("Init selenium driver")
return driver_class(**kwargs) return driver_class(**kwargs)

View File

@ -143,7 +143,7 @@ def generate_download_headers(
return headers return headers
def api(f: Callable) -> Callable: def api(f: Callable[..., FlaskResponse]) -> Callable[..., FlaskResponse]:
""" """
A decorator to label an endpoint as an API. Catches uncaught exceptions and A decorator to label an endpoint as an API. Catches uncaught exceptions and
return the response in the JSON format return the response in the JSON format
@ -383,11 +383,11 @@ class DeleteMixin: # pylint: disable=too-few-public-methods
:param primary_key: :param primary_key:
record primary key to delete record primary key to delete
""" """
item = self.datamodel.get(primary_key, self._base_filters) # type: ignore item = self.datamodel.get(primary_key, self._base_filters)
if not item: if not item:
abort(404) abort(404)
try: try:
self.pre_delete(item) # type: ignore self.pre_delete(item)
except Exception as ex: # pylint: disable=broad-except except Exception as ex: # pylint: disable=broad-except
flash(str(ex), "danger") flash(str(ex), "danger")
else: else:
@ -400,8 +400,8 @@ class DeleteMixin: # pylint: disable=too-few-public-methods
.all() .all()
) )
if self.datamodel.delete(item): # type: ignore if self.datamodel.delete(item):
self.post_delete(item) # type: ignore self.post_delete(item)
for pv in pvs: for pv in pvs:
security_manager.get_session.delete(pv) security_manager.get_session.delete(pv)
@ -411,8 +411,8 @@ class DeleteMixin: # pylint: disable=too-few-public-methods
security_manager.get_session.commit() security_manager.get_session.commit()
flash(*self.datamodel.message) # type: ignore flash(*self.datamodel.message)
self.update_redirect() # type: ignore self.update_redirect()
@action( @action(
"muldelete", __("Delete"), __("Delete all Really?"), "fa-trash", single=False "muldelete", __("Delete"), __("Delete all Really?"), "fa-trash", single=False

View File

@ -41,7 +41,7 @@ get_related_schema = {
} }
def statsd_metrics(f: Callable) -> Callable: def statsd_metrics(f: Callable[..., Any]) -> Callable[..., Any]:
""" """
Handle sending all statsd metrics from the REST API Handle sending all statsd metrics from the REST API
""" """

View File

@ -88,7 +88,9 @@ class BaseOwnedSchema(BaseSupersetSchema):
owners_field_name = "owners" owners_field_name = "owners"
@post_load @post_load
def make_object(self, data: Dict, discard: Optional[List[str]] = None) -> Model: def make_object(
self, data: Dict[str, Any], discard: Optional[List[str]] = None
) -> Model:
discard = discard or [] discard = discard or []
discard.append(self.owners_field_name) discard.append(self.owners_field_name)
instance = super().make_object(data, discard) instance = super().make_object(data, discard)

View File

@ -251,7 +251,7 @@ def check_slice_perms(self: "Superset", slice_id: int) -> None:
def _deserialize_results_payload( def _deserialize_results_payload(
payload: Union[bytes, str], query: Query, use_msgpack: Optional[bool] = False payload: Union[bytes, str], query: Query, use_msgpack: Optional[bool] = False
) -> Dict[Any, Any]: ) -> Dict[str, Any]:
logger.debug(f"Deserializing from msgpack: {use_msgpack}") logger.debug(f"Deserializing from msgpack: {use_msgpack}")
if use_msgpack: if use_msgpack:
with stats_timing( with stats_timing(
@ -278,7 +278,7 @@ def _deserialize_results_payload(
with stats_timing( with stats_timing(
"sqllab.query.results_backend_json_deserialize", stats_logger "sqllab.query.results_backend_json_deserialize", stats_logger
): ):
return json.loads(payload) # type: ignore return json.loads(payload)
def get_cta_schema_name( def get_cta_schema_name(
@ -1343,7 +1343,7 @@ class Superset(BaseSupersetView):
if "timed_refresh_immune_slices" not in md: if "timed_refresh_immune_slices" not in md:
md["timed_refresh_immune_slices"] = [] md["timed_refresh_immune_slices"] = []
new_filter_scopes: Dict[str, Dict] = {} new_filter_scopes = {}
if "filter_scopes" in data: if "filter_scopes" in data:
# replace filter_id and immune ids from old slice id to new slice id: # replace filter_id and immune ids from old slice id to new slice id:
# and remove slice ids that are not in dash anymore # and remove slice ids that are not in dash anymore
@ -2137,7 +2137,7 @@ class Superset(BaseSupersetView):
f"deprecated.{self.__class__.__name__}.select_star.database_not_found" f"deprecated.{self.__class__.__name__}.select_star.database_not_found"
) )
return json_error_response("Not found", 404) return json_error_response("Not found", 404)
schema = utils.parse_js_uri_path_item(schema, eval_undefined=True) # type: ignore schema = utils.parse_js_uri_path_item(schema, eval_undefined=True)
table_name = utils.parse_js_uri_path_item(table_name) # type: ignore table_name = utils.parse_js_uri_path_item(table_name) # type: ignore
# Check that the user can access the datasource # Check that the user can access the datasource
if not self.appbuilder.sm.can_access_datasource( if not self.appbuilder.sm.can_access_datasource(
@ -2245,7 +2245,7 @@ class Superset(BaseSupersetView):
) )
payload = utils.zlib_decompress(blob, decode=not results_backend_use_msgpack) payload = utils.zlib_decompress(blob, decode=not results_backend_use_msgpack)
obj: dict = _deserialize_results_payload( obj = _deserialize_results_payload(
payload, query, cast(bool, results_backend_use_msgpack) payload, query, cast(bool, results_backend_use_msgpack)
) )
@ -2474,9 +2474,7 @@ class Superset(BaseSupersetView):
schema: str = cast(str, query_params.get("schema")) schema: str = cast(str, query_params.get("schema"))
sql: str = cast(str, query_params.get("sql")) sql: str = cast(str, query_params.get("sql"))
try: try:
template_params: dict = json.loads( template_params = json.loads(query_params.get("templateParams") or "{}")
query_params.get("templateParams") or "{}"
)
except json.JSONDecodeError: except json.JSONDecodeError:
logger.warning( logger.warning(
f"Invalid template parameter {query_params.get('templateParams')}" f"Invalid template parameter {query_params.get('templateParams')}"

View File

@ -61,7 +61,7 @@ def get_col_type(col: Dict[Any, Any]) -> str:
def get_table_metadata( def get_table_metadata(
database: Database, table_name: str, schema_name: Optional[str] database: Database, table_name: str, schema_name: Optional[str]
) -> Dict: ) -> Dict[str, Any]:
""" """
Get table metadata information, including type, pk, fks. Get table metadata information, including type, pk, fks.
This function raises SQLAlchemyError when a schema is not found. This function raises SQLAlchemyError when a schema is not found.
@ -72,7 +72,7 @@ def get_table_metadata(
:param schema_name: schema name :param schema_name: schema name
:return: Dict table metadata ready for API response :return: Dict table metadata ready for API response
""" """
keys: List = [] keys = []
columns = database.get_columns(table_name, schema_name) columns = database.get_columns(table_name, schema_name)
primary_key = database.get_pk_constraint(table_name, schema_name) primary_key = database.get_pk_constraint(table_name, schema_name)
if primary_key and primary_key.get("constrained_columns"): if primary_key and primary_key.get("constrained_columns"):
@ -82,7 +82,7 @@ def get_table_metadata(
foreign_keys = get_foreign_keys_metadata(database, table_name, schema_name) foreign_keys = get_foreign_keys_metadata(database, table_name, schema_name)
indexes = get_indexes_metadata(database, table_name, schema_name) indexes = get_indexes_metadata(database, table_name, schema_name)
keys += foreign_keys + indexes keys += foreign_keys + indexes
payload_columns: List[Dict] = [] payload_columns: List[Dict[str, Any]] = []
for col in columns: for col in columns:
dtype = get_col_type(col) dtype = get_col_type(col)
payload_columns.append( payload_columns.append(
@ -90,7 +90,7 @@ def get_table_metadata(
"name": col["name"], "name": col["name"],
"type": dtype.split("(")[0] if "(" in dtype else dtype, "type": dtype.split("(")[0] if "(" in dtype else dtype,
"longType": dtype, "longType": dtype,
"keys": [k for k in keys if col["name"] in k.get("column_names")], "keys": [k for k in keys if col["name"] in k["column_names"]],
} }
) )
return { return {
@ -270,7 +270,7 @@ class DatabaseRestApi(DatabaseMixin, BaseSupersetModelRestApi):
""" """
self.incr_stats("init", self.table_metadata.__name__) self.incr_stats("init", self.table_metadata.__name__)
try: try:
table_info: Dict = get_table_metadata(database, table_name, schema_name) table_info = get_table_metadata(database, table_name, schema_name)
except SQLAlchemyError as ex: except SQLAlchemyError as ex:
self.incr_stats("error", self.table_metadata.__name__) self.incr_stats("error", self.table_metadata.__name__)
return self.response_422(error_msg_from_exception(ex)) return self.response_422(error_msg_from_exception(ex))

View File

@ -29,7 +29,7 @@ from superset.views.base_api import BaseSupersetModelRestApi
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def check_datasource_access(f: Callable) -> Callable: def check_datasource_access(f: Callable[..., Any]) -> Callable[..., Any]:
""" """
A Decorator that checks if a user has datasource access A Decorator that checks if a user has datasource access
""" """

View File

@ -15,7 +15,7 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import enum import enum
from typing import Type from typing import Type, Union
import simplejson as json import simplejson as json
from croniter import croniter from croniter import croniter
@ -55,7 +55,7 @@ class EmailScheduleView(
raise NotImplementedError() raise NotImplementedError()
@property @property
def schedule_type_model(self) -> Type: def schedule_type_model(self) -> Type[Union[Dashboard, Slice]]:
raise NotImplementedError() raise NotImplementedError()
page_size = 20 page_size = 20
@ -154,9 +154,7 @@ class EmailScheduleView(
info[col] = info[col].username info[col] = info[col].username
info["user"] = schedule.user.username info["user"] = schedule.user.username
info[self.schedule_type] = getattr( # type: ignore info[self.schedule_type] = getattr(schedule, self.schedule_type).id
schedule, self.schedule_type
).id
schedules.append(info) schedules.append(info)
return json_success(json.dumps(schedules, default=json_iso_dttm_ser)) return json_success(json.dumps(schedules, default=json_iso_dttm_ser))

View File

@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
from typing import Callable from typing import Any
import simplejson as json import simplejson as json
from flask import g, redirect, request, Response from flask import g, redirect, request, Response
@ -40,7 +40,7 @@ from .base import (
class QueryFilter(BaseFilter): # pylint: disable=too-few-public-methods class QueryFilter(BaseFilter): # pylint: disable=too-few-public-methods
def apply(self, query: BaseQuery, value: Callable) -> BaseQuery: def apply(self, query: BaseQuery, value: Any) -> BaseQuery:
""" """
Filter queries to only those owned by current user. If Filter queries to only those owned by current user. If
can_access_all_queries permission is set a user can list all queries can_access_all_queries permission is set a user can list all queries

View File

@ -35,7 +35,7 @@ from superset.utils.core import QueryStatus, TimeRangeEndpoint
from superset.viz import BaseViz from superset.viz import BaseViz
if is_feature_enabled("SIP_38_VIZ_REARCHITECTURE"): if is_feature_enabled("SIP_38_VIZ_REARCHITECTURE"):
from superset import viz_sip38 as viz # type: ignore from superset import viz_sip38 as viz
else: else:
from superset import viz # type: ignore from superset import viz # type: ignore
@ -318,9 +318,9 @@ def get_dashboard_extra_filters(
def build_extra_filters( def build_extra_filters(
layout: Dict, layout: Dict[str, Dict[str, Any]],
filter_scopes: Dict, filter_scopes: Dict[str, Dict[str, Any]],
default_filters: Dict[str, Dict[str, List]], default_filters: Dict[str, Dict[str, List[Any]]],
slice_id: int, slice_id: int,
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
extra_filters = [] extra_filters = []
@ -343,7 +343,9 @@ def build_extra_filters(
return extra_filters return extra_filters
def is_slice_in_container(layout: Dict, container_id: str, slice_id: int) -> bool: def is_slice_in_container(
layout: Dict[str, Dict[str, Any]], container_id: str, slice_id: int
) -> bool:
if container_id == "ROOT_ID": if container_id == "ROOT_ID":
return True return True

View File

@ -2720,7 +2720,7 @@ class PairedTTestViz(BaseViz):
else: else:
cols.append(col) cols.append(col)
df.columns = cols df.columns = cols
data: Dict = {} data: Dict[str, List[Dict[str, Any]]] = {}
series = df.to_dict("series") series = df.to_dict("series")
for nameSet in df.columns: for nameSet in df.columns:
# If no groups are defined, nameSet will be the metric name # If no groups are defined, nameSet will be the metric name
@ -2750,7 +2750,7 @@ class RoseViz(NVD3TimeSeriesViz):
return None return None
data = super().get_data(df) data = super().get_data(df)
result: Dict = {} result: Dict[str, List[Dict[str, str]]] = {}
for datum in data: # type: ignore for datum in data: # type: ignore
key = datum["key"] key = datum["key"]
for val in datum["values"]: for val in datum["values"]:

View File

@ -18,7 +18,7 @@
"""Unit tests for Superset""" """Unit tests for Superset"""
import imp import imp
import json import json
from typing import Dict, Union, List from typing import Any, Dict, Union, List
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
import pandas as pd import pandas as pd
@ -397,7 +397,9 @@ class SupersetTestCase(TestCase):
mock_method.assert_called_once_with("error", func_name) mock_method.assert_called_once_with("error", func_name)
return rv return rv
def post_assert_metric(self, uri: str, data: Dict, func_name: str) -> Response: def post_assert_metric(
self, uri: str, data: Dict[str, Any], func_name: str
) -> Response:
""" """
Simple client post with an extra assertion for statsd metrics Simple client post with an extra assertion for statsd metrics
@ -417,7 +419,9 @@ class SupersetTestCase(TestCase):
mock_method.assert_called_once_with("error", func_name) mock_method.assert_called_once_with("error", func_name)
return rv return rv
def put_assert_metric(self, uri: str, data: Dict, func_name: str) -> Response: def put_assert_metric(
self, uri: str, data: Dict[str, Any], func_name: str
) -> Response:
""" """
Simple client put with an extra assertion for statsd metrics Simple client put with an extra assertion for statsd metrics

View File

@ -20,7 +20,7 @@ from copy import copy
from cachelib.redis import RedisCache from cachelib.redis import RedisCache
from flask import Flask from flask import Flask
from superset.config import * # type: ignore from superset.config import *
AUTH_USER_REGISTRATION_ROLE = "alpha" AUTH_USER_REGISTRATION_ROLE = "alpha"
SQLALCHEMY_DATABASE_URI = "sqlite:///" + os.path.join(DATA_DIR, "unittests.db") SQLALCHEMY_DATABASE_URI = "sqlite:///" + os.path.join(DATA_DIR, "unittests.db")