From 7f6dbf838e4e527e640a002ce20bf5da1abf4a98 Mon Sep 17 00:00:00 2001
From: John Bodley <4567245+john-bodley@users.noreply.github.com>
Date: Mon, 25 May 2020 12:32:49 -0700
Subject: [PATCH] style: enforcing mypy typing for connectors (#9824)

Co-authored-by: John Bodley
---
 setup.cfg                           |   4 +-
 superset/connectors/base/models.py  |  61 ++++----
 superset/connectors/base/views.py   |   3 +-
 superset/connectors/druid/models.py | 210 +++++++++++++++++------
 superset/connectors/druid/views.py  |  25 ++--
 superset/connectors/sqla/models.py  | 132 ++++++++++-------
 superset/connectors/sqla/views.py   |  18 ++-
 superset/db_engine_specs/base.py    |   2 +-
 superset/errors.py                  |   3 +-
 superset/jinja_context.py           |  32 +++--
 superset/sql_parse.py               |   2 +-
 superset/typing.py                  |  12 ++
 superset/utils/core.py              |  25 ++--
 superset/utils/import_datasource.py |  13 +-
 superset/views/base.py              |  28 ++--
 superset/viz.py                     |  32 +++--
 superset/viz_sip38.py               |  30 ++--
 17 files changed, 392 insertions(+), 240 deletions(-)

diff --git a/setup.cfg b/setup.cfg
index bfef8affc..4909cd641 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -45,7 +45,7 @@ combine_as_imports = true
 include_trailing_comma = true
 line_length = 88
 known_first_party = superset
-known_third_party =alembic,apispec,backoff,bleach,cachelib,celery,click,colorama,contextlib2,croniter,cryptography,dateutil,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,geohash,geopy,humanize,isodate,jinja2,markdown,markupsafe,marshmallow,msgpack,numpy,pandas,parsedatetime,pathlib2,polyline,prison,pyarrow,pyhive,pytz,retry,selenium,setuptools,simplejson,sphinx_rtd_theme,sqlalchemy,sqlalchemy_utils,sqlparse,werkzeug,wtforms,wtforms_json,yaml
+known_third_party =alembic,apispec,backoff,bleach,cachelib,celery,click,colorama,contextlib2,croniter,cryptography,dataclasses,dateutil,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,geohash,geopy,humanize,isodate,jinja2,markdown,markupsafe,marshmallow,msgpack,numpy,pandas,parsedatetime,pathlib2,polyline,prison,pyarrow,pyhive,pytz,retry,selenium,setuptools,simplejson,sphinx_rtd_theme,sqlalchemy,sqlalchemy_utils,sqlparse,werkzeug,wtforms,wtforms_json,yaml
 multi_line_output = 3
 order_by_type = false
 
@@ -53,7 +53,7 @@ order_by_type = false
 ignore_missing_imports = true
 no_implicit_optional = true
 
-[mypy-superset.bin.*,superset.charts.*,superset.datasets.*,superset.dashboards.*,superset.commands.*,superset.common.*,superset.dao.*,superset.db_engine_specs.*,superset.db_engines.*,superset.examples.*,superset.migrations.*,superset.models.*,superset.queries.*,superset.security.*,superset.sql_validators.*,superset.tasks.*,superset.translations.*]
+[mypy-superset.bin.*,superset.charts.*,superset.commands.*,superset.common.*,superset.connectors.*,superset.dao.*,superset.dashboards.*,superset.datasets.*,superset.db_engine_specs.*,superset.db_engines.*,superset.examples.*,superset.migrations.*,superset.models.*,superset.queries.*,superset.security.*,superset.sql_validators.*,superset.tasks.*,superset.translations.*]
 check_untyped_defs = true
 disallow_untyped_calls = true
 disallow_untyped_defs = true
diff --git a/superset/connectors/base/models.py b/superset/connectors/base/models.py
index 2b051d35e..c4e62a6e2 100644
--- a/superset/connectors/base/models.py
+++ b/superset/connectors/base/models.py
@@ -20,12 +20,12 @@
 from typing import Any, Dict, Hashable, List, Optional, Type
 
 from
flask_appbuilder.security.sqla.models import User from sqlalchemy import and_, Boolean, Column, Integer, String, Text from sqlalchemy.ext.declarative import declared_attr -from sqlalchemy.orm import foreign, Query, relationship +from sqlalchemy.orm import foreign, Query, relationship, RelationshipProperty from superset.constants import NULL_STRING from superset.models.helpers import AuditMixinNullable, ImportMixin, QueryResult from superset.models.slice import Slice -from superset.typing import FilterValue, FilterValues +from superset.typing import FilterValue, FilterValues, QueryObjectDict from superset.utils import core as utils METRIC_FORM_DATA_PARAMS = [ @@ -93,7 +93,7 @@ class BaseDatasource( update_from_object_fields: List[str] @declared_attr - def slices(self): + def slices(self) -> RelationshipProperty: return relationship( "Slice", primaryjoin=lambda: and_( @@ -117,7 +117,7 @@ class BaseDatasource( return sorted([c.column_name for c in self.columns], key=lambda x: x or "") @property - def columns_types(self) -> Dict: + def columns_types(self) -> Dict[str, str]: return {c.column_name: c.type for c in self.columns} @property @@ -125,7 +125,7 @@ class BaseDatasource( return "timestamp" @property - def datasource_name(self): + def datasource_name(self) -> str: raise NotImplementedError() @property @@ -143,7 +143,7 @@ class BaseDatasource( return sorted([c.column_name for c in self.columns if c.filterable]) @property - def dttm_cols(self) -> List: + def dttm_cols(self) -> List[str]: return [] @property @@ -182,7 +182,7 @@ class BaseDatasource( } @property - def select_star(self): + def select_star(self) -> Optional[str]: pass @property @@ -336,18 +336,18 @@ class BaseDatasource( values = None return values - def external_metadata(self): + def external_metadata(self) -> List[Dict[str, str]]: """Returns column information from the external system""" raise NotImplementedError() - def get_query_str(self, query_obj) -> str: + def get_query_str(self, query_obj: QueryObjectDict) -> str: """Returns a query as a string This is used to be displayed to the user so that she/he can understand what is taking place behind the scene""" raise NotImplementedError() - def query(self, query_obj) -> QueryResult: + def query(self, query_obj: QueryObjectDict) -> QueryResult: """Executes the query and returns a dataframe query_obj is a dictionary representing Superset's query interface. 
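Note on the `QueryObjectDict` alias this hunk introduces: it is deliberately loose (`Dict[str, Any]`), so subclasses only gain documentation of intent, not stricter checking. A minimal, hypothetical sketch of a datasource satisfying the newly typed `get_query_str`/`query` interface might look like this (`ExampleDatasource` and its body are invented for illustration, not part of the patch):

from typing import Any, Dict

QueryObjectDict = Dict[str, Any]  # mirrors the alias added to superset/typing.py

class ExampleDatasource:
    def get_query_str(self, query_obj: QueryObjectDict) -> str:
        # Real connectors compile SQL or Druid JSON; this just echoes the metrics.
        return "SELECT " + ", ".join(query_obj.get("metrics", []))

    def query(self, query_obj: QueryObjectDict) -> str:
        # A real implementation would execute the query and wrap it in QueryResult.
        return self.get_query_str(query_obj)

print(ExampleDatasource().query({"metrics": ["count"]}))  # SELECT count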
@@ -363,7 +363,7 @@ class BaseDatasource( raise NotImplementedError() @staticmethod - def default_query(qry) -> Query: + def default_query(qry: Query) -> Query: return qry def get_column(self, column_name: Optional[str]) -> Optional["BaseColumn"]: @@ -376,8 +376,8 @@ class BaseDatasource( @staticmethod def get_fk_many_from_list( - object_list, fkmany, fkmany_class, key_attr - ): # pylint: disable=too-many-locals + object_list: List[Any], fkmany: List[Column], fkmany_class: Type, key_attr: str, + ) -> List[Column]: # pylint: disable=too-many-locals """Update ORM one-to-many list from object list Used for syncing metrics and columns using the same code""" @@ -390,8 +390,9 @@ class BaseDatasource( # sync existing fks for fk in fkmany: obj = object_dict.get(getattr(fk, key_attr)) - for attr in fkmany_class.update_from_object_fields: - setattr(fk, attr, obj.get(attr)) + if obj: + for attr in fkmany_class.update_from_object_fields: + setattr(fk, attr, obj.get(attr)) # create new fks new_fks = [] @@ -409,7 +410,7 @@ class BaseDatasource( fkmany += new_fks return fkmany - def update_from_object(self, obj) -> None: + def update_from_object(self, obj: Dict[str, Any]) -> None: """Update datasource from a data structure The UI's table editor crafts a complex data structure that @@ -426,18 +427,26 @@ class BaseDatasource( self.owners = obj.get("owners", []) # Syncing metrics - metrics = self.get_fk_many_from_list( - obj.get("metrics"), self.metrics, self.metric_class, "metric_name" + metrics = ( + self.get_fk_many_from_list( + obj["metrics"], self.metrics, self.metric_class, "metric_name" + ) + if self.metric_class and "metrics" in obj + else [] ) self.metrics = metrics # Syncing columns - self.columns = self.get_fk_many_from_list( - obj.get("columns"), self.columns, self.column_class, "column_name" + self.columns = ( + self.get_fk_many_from_list( + obj["columns"], self.columns, self.column_class, "column_name" + ) + if self.column_class and "columns" in obj + else [] ) def get_extra_cache_keys( # pylint: disable=no-self-use - self, query_obj: Dict[str, Any] # pylint: disable=unused-argument + self, query_obj: QueryObjectDict # pylint: disable=unused-argument ) -> List[Hashable]: """ If a datasource needs to provide additional keys for calculation of cache keys, those can be provided via this method @@ -474,7 +483,7 @@ class BaseColumn(AuditMixinNullable, ImportMixin): # [optional] Set this to support import/export functionality export_fields: List[Any] = [] - def __repr__(self): + def __repr__(self) -> str: return self.column_name num_types = ( @@ -505,11 +514,11 @@ class BaseColumn(AuditMixinNullable, ImportMixin): return self.type and any(map(lambda t: t in self.type.upper(), self.str_types)) @property - def expression(self): + def expression(self) -> Column: raise NotImplementedError() @property - def python_date_format(self): + def python_date_format(self) -> Column: raise NotImplementedError() @property @@ -557,11 +566,11 @@ class BaseMetric(AuditMixinNullable, ImportMixin): """ @property - def perm(self): + def perm(self) -> Optional[str]: raise NotImplementedError() @property - def expression(self): + def expression(self) -> Column: raise NotImplementedError() @property diff --git a/superset/connectors/base/views.py b/superset/connectors/base/views.py index 7695a893f..150c05c55 100644 --- a/superset/connectors/base/views.py +++ b/superset/connectors/base/views.py @@ -16,12 +16,13 @@ # under the License. 
from flask import Markup +from superset.connectors.base.models import BaseDatasource from superset.exceptions import SupersetException from superset.views.base import SupersetModelView class DatasourceModelView(SupersetModelView): - def pre_delete(self, item): + def pre_delete(self, item: BaseDatasource) -> None: if item.slices: raise SupersetException( Markup( diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py index 8b841cc34..b0a333274 100644 --- a/superset/connectors/druid/models.py +++ b/superset/connectors/druid/models.py @@ -24,7 +24,18 @@ from copy import deepcopy from datetime import datetime, timedelta from distutils.version import LooseVersion from multiprocessing.pool import ThreadPool -from typing import Any, cast, Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import ( + Any, + Callable, + cast, + Dict, + Iterable, + List, + Optional, + Set, + Tuple, + Union, +) import pandas as pd import sqlalchemy as sa @@ -54,7 +65,7 @@ from superset.constants import NULL_STRING from superset.exceptions import SupersetException from superset.models.core import Database from superset.models.helpers import AuditMixinNullable, ImportMixin, QueryResult -from superset.typing import FilterValues +from superset.typing import FilterValues, Granularity, Metric, QueryObjectDict from superset.utils import core as utils, import_datasource try: @@ -99,7 +110,7 @@ logger = logging.getLogger(__name__) try: # Postaggregator might not have been imported. class JavascriptPostAggregator(Postaggregator): - def __init__(self, name, field_names, function): + def __init__(self, name: str, field_names: List[str], function: str) -> None: self.post_aggregator = { "type": "javascript", "fieldNames": field_names, @@ -111,7 +122,7 @@ try: class CustomPostAggregator(Postaggregator): """A way to allow users to specify completely custom PostAggregators""" - def __init__(self, name, post_aggregator): + def __init__(self, name: str, post_aggregator: Dict[str, Any]) -> None: self.name = name self.post_aggregator = post_aggregator @@ -121,7 +132,7 @@ except NameError: # Function wrapper because bound methods cannot # be passed to processes -def _fetch_metadata_for(datasource): +def _fetch_metadata_for(datasource: "DruidDatasource") -> Optional[Dict[str, Any]]: return datasource.latest_metadata() @@ -155,10 +166,10 @@ class DruidCluster(Model, AuditMixinNullable, ImportMixin): update_from_object_fields = export_fields export_children = ["datasources"] - def __repr__(self): + def __repr__(self) -> str: return self.verbose_name if self.verbose_name else self.cluster_name - def __html__(self): + def __html__(self) -> str: return self.__repr__() @property @@ -166,7 +177,7 @@ class DruidCluster(Model, AuditMixinNullable, ImportMixin): return {"id": self.id, "name": self.cluster_name, "backend": "druid"} @staticmethod - def get_base_url(host, port) -> str: + def get_base_url(host: str, port: int) -> str: if not re.match("http(s)?://", host): host = "http://" + host @@ -335,7 +346,7 @@ class DruidColumn(Model, BaseColumn): update_from_object_fields = export_fields export_parent = "datasource" - def __repr__(self): + def __repr__(self) -> str: return self.column_name or str(self.id) @property @@ -380,7 +391,7 @@ class DruidColumn(Model, BaseColumn): @classmethod def import_obj(cls, i_column: "DruidColumn") -> "DruidColumn": - def lookup_obj(lookup_column: "DruidColumn") -> Optional["DruidColumn"]: + def lookup_obj(lookup_column: DruidColumn) -> Optional[DruidColumn]: return ( 
db.session.query(DruidColumn) .filter( @@ -423,7 +434,7 @@ class DruidMetric(Model, BaseMetric): export_parent = "datasource" @property - def expression(self): + def expression(self) -> Column: return self.json @property @@ -558,8 +569,8 @@ class DruidDatasource(Model, BaseDatasource): obj=self ) - def update_from_object(self, obj): - return NotImplementedError() + def update_from_object(self, obj: Dict[str, Any]) -> None: + raise NotImplementedError() @property def link(self) -> Markup: @@ -594,7 +605,7 @@ class DruidDatasource(Model, BaseDatasource): "time_grains": ["now"], } - def __repr__(self): + def __repr__(self) -> str: return self.datasource_name @renders("datasource_name") @@ -634,7 +645,7 @@ class DruidDatasource(Model, BaseDatasource): db.session, i_datasource, lookup_cluster, lookup_datasource, import_time ) - def latest_metadata(self): + def latest_metadata(self) -> Optional[Dict[str, Any]]: """Returns segment metadata from the latest segment""" logger.info("Syncing datasource [{}]".format(self.datasource_name)) client = self.cluster.get_pydruid_client() @@ -686,6 +697,7 @@ class DruidDatasource(Model, BaseDatasource): logger.exception(ex) if segment_metadata: return segment_metadata[-1]["columns"] + return None def refresh_metrics(self) -> None: for col in self.columns: @@ -772,7 +784,7 @@ class DruidDatasource(Model, BaseDatasource): session.commit() @staticmethod - def time_offset(granularity: Union[str, Dict]) -> int: + def time_offset(granularity: Granularity) -> int: if granularity == "week_ending_saturday": return 6 * 24 * 3600 * 1000 # 6 days return 0 @@ -795,7 +807,7 @@ class DruidDatasource(Model, BaseDatasource): @staticmethod def granularity( period_name: str, timezone: Optional[str] = None, origin: Optional[str] = None - ) -> Union[str, Dict]: + ) -> Union[Dict[str, str], str]: if not period_name or period_name == "all": return "all" iso_8601_dict = { @@ -817,7 +829,7 @@ class DruidDatasource(Model, BaseDatasource): "year": "P1Y", } - granularity: Dict[str, Union[str, float]] = {"type": "period"} + granularity = {"type": "period"} if timezone: granularity["timeZone"] = timezone @@ -840,12 +852,12 @@ class DruidDatasource(Model, BaseDatasource): else: granularity["type"] = "duration" granularity["duration"] = ( - utils.parse_human_timedelta(period_name).total_seconds() * 1000 + utils.parse_human_timedelta(period_name).total_seconds() * 1000 # type: ignore ) return granularity @staticmethod - def get_post_agg(mconf: Dict) -> "Postaggregator": + def get_post_agg(mconf: Dict[str, Any]) -> "Postaggregator": """ For a metric specified as `postagg` returns the kind of post aggregation for pydruid. 
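For reference, the behaviour behind the newly typed `granularity()` signature can be sketched as below. This is a simplified, assumption-laden reduction: the dictionary entries are an illustrative subset, and the real method also special-cases week_ending_saturday/week_starting_sunday and falls back to a {"type": "duration"} granularity via parse_human_timedelta for unknown period names.

from typing import Dict, Optional, Union

ISO_8601 = {"1 minute": "PT1M", "hour": "PT1H", "day": "P1D",
            "week": "P1W", "month": "P1M", "year": "P1Y"}

def granularity(period_name: str, timezone: Optional[str] = None,
                origin: Optional[str] = None) -> Union[Dict[str, str], str]:
    if not period_name or period_name == "all":
        return "all"
    grain: Dict[str, str] = {"type": "period"}
    if timezone:
        grain["timeZone"] = timezone
    if origin:
        grain["origin"] = origin
    # Assume unknown names are already ISO-8601; the real method instead
    # builds a duration-type granularity here.
    grain["period"] = ISO_8601.get(period_name, period_name)
    return grain

assert granularity("all") == "all"
assert granularity("day", timezone="UTC") == {"type": "period", "timeZone": "UTC", "period": "P1D"}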
@@ -904,7 +916,13 @@ class DruidDatasource(Model, BaseDatasource): return list(set(field_names)) @staticmethod - def resolve_postagg(postagg, post_aggs, agg_names, visited_postaggs, metrics_dict): + def resolve_postagg( + postagg: DruidMetric, + post_aggs: Dict[str, Any], + agg_names: Set[str], + visited_postaggs: Set[str], + metrics_dict: Dict[str, DruidMetric], + ) -> None: mconf = postagg.json_obj required_fields = set( DruidDatasource.recursive_get_fields(mconf) + mconf.get("fieldNames", []) @@ -939,9 +957,7 @@ class DruidDatasource(Model, BaseDatasource): @staticmethod def metrics_and_post_aggs( - metrics: List[Union[Dict, str]], - metrics_dict: Dict[str, DruidMetric], - druid_version=None, + metrics: List[Union[Dict, str]], metrics_dict: Dict[str, DruidMetric], ) -> Tuple[OrderedDict, OrderedDict]: # Separate metrics into those that are aggregations # and those that are post aggregations @@ -998,10 +1014,17 @@ class DruidDatasource(Model, BaseDatasource): df = client.export_pandas() return df[column_name].to_list() - def get_query_str(self, query_obj, phase=1, client=None): + def get_query_str( + self, + query_obj: QueryObjectDict, + phase: int = 1, + client: Optional["PyDruid"] = None, + ) -> str: return self.run_query(client=client, phase=phase, **query_obj) - def _add_filter_from_pre_query_data(self, df: pd.DataFrame, dimensions, dim_filter): + def _add_filter_from_pre_query_data( + self, df: pd.DataFrame, dimensions: List[Any], dim_filter: "Filter" + ) -> "Filter": ret = dim_filter if not df.empty: new_filters = [] @@ -1043,7 +1066,7 @@ class DruidDatasource(Model, BaseDatasource): return ret @staticmethod - def druid_type_from_adhoc_metric(adhoc_metric: Dict) -> str: + def druid_type_from_adhoc_metric(adhoc_metric: Dict[str, Any]) -> str: column_type = adhoc_metric["column"]["type"].lower() aggregate = adhoc_metric["aggregate"].lower() @@ -1115,12 +1138,14 @@ class DruidDatasource(Model, BaseDatasource): ) @staticmethod - def _dimensions_to_values(dimensions): + def _dimensions_to_values( + dimensions: List[Union[Dict[str, str], str]] + ) -> List[Union[Dict[str, str], str]]: """ Replace dimensions specs with their `dimension` values, and ignore those without """ - values = [] + values: List[Union[Dict[str, str], str]] = [] for dimension in dimensions: if isinstance(dimension, dict): if "extractionFn" in dimension: @@ -1133,37 +1158,37 @@ class DruidDatasource(Model, BaseDatasource): return values @staticmethod - def sanitize_metric_object(metric: Dict) -> None: + def sanitize_metric_object(metric: Metric) -> None: """ Update a metric with the correct type if necessary. 
:param dict metric: The metric to sanitize """ if ( utils.is_adhoc_metric(metric) - and metric["column"]["type"].upper() == "FLOAT" + and metric["column"]["type"].upper() == "FLOAT" # type: ignore ): - metric["column"]["type"] = "DOUBLE" + metric["column"]["type"] = "DOUBLE" # type: ignore def run_query( # druid self, - metrics, - granularity, - from_dttm, - to_dttm, - columns=None, - groupby=None, - filter=None, - is_timeseries=True, - timeseries_limit=None, - timeseries_limit_metric=None, - row_limit=None, - inner_from_dttm=None, - inner_to_dttm=None, - orderby=None, - extras=None, - phase=2, - client=None, - order_desc=True, + metrics: List[Metric], + granularity: str, + from_dttm: datetime, + to_dttm: datetime, + columns: Optional[List[str]] = None, + groupby: Optional[List[str]] = None, + filter: Optional[List[Dict[str, Any]]] = None, + is_timeseries: Optional[bool] = True, + timeseries_limit: Optional[int] = None, + timeseries_limit_metric: Optional[Metric] = None, + row_limit: Optional[int] = None, + inner_from_dttm: Optional[datetime] = None, + inner_to_dttm: Optional[datetime] = None, + orderby: Optional[Any] = None, + extras: Optional[Dict[str, Any]] = None, + phase: int = 2, + client: Optional["PyDruid"] = None, + order_desc: bool = True, ) -> str: """Runs a query against Druid and returns a dataframe. """ @@ -1190,17 +1215,16 @@ class DruidDatasource(Model, BaseDatasource): ) < LooseVersion("0.11.0"): for metric in metrics: self.sanitize_metric_object(metric) - self.sanitize_metric_object(timeseries_limit_metric) + if timeseries_limit_metric: + self.sanitize_metric_object(timeseries_limit_metric) aggregations, post_aggs = DruidDatasource.metrics_and_post_aggs( metrics, metrics_dict ) # the dimensions list with dimensionSpecs expanded - - dimensions = self.get_dimensions( - columns if IS_SIP_38 else groupby, columns_dict - ) + columns_ = columns if IS_SIP_38 else groupby + dimensions = self.get_dimensions(columns_, columns_dict) if columns_ else [] extras = extras or {} qry = dict( @@ -1217,17 +1241,24 @@ class DruidDatasource(Model, BaseDatasource): if is_timeseries: qry["context"] = dict(skipEmptyBuckets=True) - filters = DruidDatasource.get_filters(filter, self.num_cols, columns_dict) + filters = ( + DruidDatasource.get_filters(filter, self.num_cols, columns_dict) + if filter + else None + ) if filters: qry["filter"] = filters - having_filters = self.get_having_filters(extras.get("having_druid")) - if having_filters: - qry["having"] = having_filters + if "having_druid" in extras: + having_filters = self.get_having_filters(extras["having_druid"]) + if having_filters: + qry["having"] = having_filters + else: + having_filters = None order_direction = "descending" if order_desc else "ascending" - if (IS_SIP_38 and not metrics and "__time" not in columns) or ( + if (IS_SIP_38 and not metrics and columns and "__time" not in columns) or ( not IS_SIP_38 and columns ): columns.append("__time") @@ -1240,7 +1271,7 @@ class DruidDatasource(Model, BaseDatasource): qry["limit"] = row_limit client.scan(**qry) elif (IS_SIP_38 and columns) or ( - not IS_SIP_38 and len(groupby) == 0 and not having_filters + not IS_SIP_38 and not groupby and not having_filters ): logger.info("Running timeseries query for no groupby values") del qry["dimensions"] @@ -1249,13 +1280,14 @@ class DruidDatasource(Model, BaseDatasource): not having_filters and order_desc and ( - (IS_SIP_38 and len(columns) == 1) - or (not IS_SIP_38 and len(groupby) == 1) + (IS_SIP_38 and columns and len(columns) == 1) + or (not 
IS_SIP_38 and groupby and len(groupby) == 1) ) ): dim = list(qry["dimensions"])[0] logger.info("Running two-phase topn query for dimension [{}]".format(dim)) pre_qry = deepcopy(qry) + order_by: Optional[str] = None if timeseries_limit_metric: order_by = utils.get_metric_name(timeseries_limit_metric) aggs_dict, post_aggs_dict = DruidDatasource.metrics_and_post_aggs( @@ -1275,7 +1307,7 @@ class DruidDatasource(Model, BaseDatasource): pre_qry["granularity"] = "all" pre_qry["threshold"] = min(row_limit, timeseries_limit or row_limit) pre_qry["metric"] = order_by - pre_qry["dimension"] = self._dimensions_to_values(qry.get("dimensions"))[0] + pre_qry["dimension"] = self._dimensions_to_values(qry["dimensions"])[0] del pre_qry["dimensions"] client.topn(**pre_qry) @@ -1303,10 +1335,7 @@ class DruidDatasource(Model, BaseDatasource): qry["metric"] = list(qry["aggregations"].keys())[0] client.topn(**qry) logger.info("Phase 2 Complete") - elif ( - having_filters - or ((IS_SIP_38 and columns) or (not IS_SIP_38 and len(groupby))) > 0 - ): + elif having_filters or ((IS_SIP_38 and columns) or (not IS_SIP_38 and groupby)): # If grouping on multiple fields or using a having filter # we have to force a groupby query logger.info("Running groupby query for dimensions [{}]".format(dimensions)) @@ -1322,13 +1351,13 @@ class DruidDatasource(Model, BaseDatasource): set([x for x in pre_qry_dims if not isinstance(x, dict)]) ) dict_dims = [x for x in pre_qry_dims if isinstance(x, dict)] - pre_qry["dimensions"] = non_dict_dims + dict_dims + pre_qry["dimensions"] = non_dict_dims + dict_dims # type: ignore order_by = None if metrics: order_by = utils.get_metric_name(metrics[0]) else: - order_by = pre_qry_dims[0] + order_by = pre_qry_dims[0] # type: ignore if timeseries_limit_metric: order_by = utils.get_metric_name(timeseries_limit_metric) @@ -1366,7 +1395,7 @@ class DruidDatasource(Model, BaseDatasource): if df is None: df = pd.DataFrame() qry["filter"] = self._add_filter_from_pre_query_data( - df, pre_qry["dimensions"], filters + df, pre_qry["dimensions"], qry["filter"] ) qry["limit_spec"] = None if row_limit: @@ -1446,7 +1475,7 @@ class DruidDatasource(Model, BaseDatasource): time_offset = DruidDatasource.time_offset(query_obj["granularity"]) - def increment_timestamp(ts): + def increment_timestamp(ts: str) -> datetime: dt = utils.parse_human_datetime(ts).replace(tzinfo=DRUID_TZ) return dt + timedelta(milliseconds=time_offset) @@ -1458,7 +1487,17 @@ class DruidDatasource(Model, BaseDatasource): ) @staticmethod - def _create_extraction_fn(dim_spec): + def _create_extraction_fn( + dim_spec: Dict[str, Any] + ) -> Tuple[ + str, + Union[ + "MapLookupExtraction", + "RegexExtraction", + "RegisteredLookupExtraction", + "TimeFormatExtraction", + ], + ]: extraction_fn = None if dim_spec and "extractionFn" in dim_spec: col = dim_spec["dimension"] @@ -1487,7 +1526,12 @@ class DruidDatasource(Model, BaseDatasource): return (col, extraction_fn) @classmethod - def get_filters(cls, raw_filters, num_cols, columns_dict) -> "Filter": + def get_filters( + cls, + raw_filters: List[Dict[str, Any]], + num_cols: List[str], + columns_dict: Dict[str, DruidColumn], + ) -> "Filter": """Given Superset filter data structure, returns pydruid Filter(s)""" filters = None for flt in raw_filters: @@ -1641,7 +1685,9 @@ class DruidDatasource(Model, BaseDatasource): return cond - def get_having_filters(self, raw_filters: List[Dict[str, Any]]) -> "Having": + def get_having_filters( + self, raw_filters: List[Dict[str, Any]] + ) -> Optional["Having"]: 
filters = None reversed_op_map = { FilterOperator.NOT_EQUALS.value: FilterOperator.EQUALS.value, @@ -1673,16 +1719,18 @@ class DruidDatasource(Model, BaseDatasource): @classmethod def query_datasources_by_name( - cls, session: Session, database: Database, datasource_name: str, schema=None + cls, + session: Session, + database: Database, + datasource_name: str, + schema: Optional[str] = None, ) -> List["DruidDatasource"]: return [] - def external_metadata(self) -> List[Dict]: + def external_metadata(self) -> List[Dict[str, Any]]: self.merge_flag = True - return [ - {"name": k, "type": v.get("type")} - for k, v in self.latest_metadata().items() - ] + latest_metadata = self.latest_metadata() or {} + return [{"name": k, "type": v.get("type")} for k, v in latest_metadata.items()] sa.event.listen(DruidDatasource, "after_insert", security_manager.set_perm) diff --git a/superset/connectors/druid/views.py b/superset/connectors/druid/views.py index 15e2d3df3..29932faaf 100644 --- a/superset/connectors/druid/views.py +++ b/superset/connectors/druid/views.py @@ -31,6 +31,7 @@ from superset import app, appbuilder, db, security_manager from superset.connectors.base.views import DatasourceModelView from superset.connectors.connector_registry import ConnectorRegistry from superset.constants import RouteMethod +from superset.typing import FlaskResponse from superset.utils import core as utils from superset.views.base import ( BaseSupersetView, @@ -106,7 +107,7 @@ class DruidColumnInlineView(CompactCRUDMixin, SupersetModelView): edit_form_extra_fields = add_form_extra_fields - def pre_update(self, col): + def pre_update(self, col: "DruidColumnInlineView") -> None: # If a dimension spec JSON is given, ensure that it is # valid JSON and that `outputName` is specified if col.dimension_spec_json: @@ -128,10 +129,10 @@ class DruidColumnInlineView(CompactCRUDMixin, SupersetModelView): ) ) - def post_update(self, col): + def post_update(self, col: "DruidColumnInlineView") -> None: col.refresh_metrics() - def post_add(self, col): + def post_add(self, col: "DruidColumnInlineView") -> None: self.post_update(col) @@ -240,13 +241,13 @@ class DruidClusterModelView(SupersetModelView, DeleteMixin, YamlExportMixin): yaml_dict_key = "databases" - def pre_add(self, cluster): + def pre_add(self, cluster: "DruidClusterModelView") -> None: security_manager.add_permission_view_menu("database_access", cluster.perm) - def pre_update(self, cluster): + def pre_update(self, cluster: "DruidClusterModelView") -> None: self.pre_add(cluster) - def _delete(self, pk): + def _delete(self, pk: int) -> None: DeleteMixin._delete(self, pk) @@ -334,7 +335,7 @@ class DruidDatasourceModelView(DatasourceModelView, DeleteMixin, YamlExportMixin "modified": _("Modified"), } - def pre_add(self, datasource): + def pre_add(self, datasource: "DruidDatasourceModelView") -> None: with db.session.no_autoflush: query = db.session.query(models.DruidDatasource).filter( models.DruidDatasource.datasource_name == datasource.datasource_name, @@ -343,7 +344,7 @@ class DruidDatasourceModelView(DatasourceModelView, DeleteMixin, YamlExportMixin if db.session.query(query.exists()).scalar(): raise Exception(get_datasource_exist_error_msg(datasource.full_name)) - def post_add(self, datasource): + def post_add(self, datasource: "DruidDatasourceModelView") -> None: datasource.refresh_metrics() security_manager.add_permission_view_menu( "datasource_access", datasource.get_perm() @@ -353,10 +354,10 @@ class DruidDatasourceModelView(DatasourceModelView, DeleteMixin, 
YamlExportMixin "schema_access", datasource.schema_perm ) - def post_update(self, datasource): + def post_update(self, datasource: "DruidDatasourceModelView") -> None: self.post_add(datasource) - def _delete(self, pk): + def _delete(self, pk: int) -> None: DeleteMixin._delete(self, pk) @@ -365,7 +366,7 @@ class Druid(BaseSupersetView): @has_access @expose("/refresh_datasources/") - def refresh_datasources(self, refresh_all=True): + def refresh_datasources(self, refresh_all: bool = True) -> FlaskResponse: """endpoint that refreshes druid datasources metadata""" session = db.session() DruidCluster = ConnectorRegistry.sources["druid"].cluster_class @@ -397,7 +398,7 @@ class Druid(BaseSupersetView): @has_access @expose("/scan_new_datasources/") - def scan_new_datasources(self): + def scan_new_datasources(self) -> FlaskResponse: """ Calling this endpoint will cause a scan for new datasources only and add them. diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 3e50280d1..0e0d8523a 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -54,10 +54,15 @@ from superset.connectors.base.models import BaseColumn, BaseDatasource, BaseMetr from superset.constants import NULL_STRING from superset.db_engine_specs.base import TimestampExpression from superset.exceptions import DatabaseNotFound -from superset.jinja_context import ExtraCache, get_template_processor +from superset.jinja_context import ( + BaseTemplateProcessor, + ExtraCache, + get_template_processor, +) from superset.models.annotations import Annotation from superset.models.core import Database from superset.models.helpers import AuditMixinNullable, QueryResult +from superset.typing import Metric, QueryObjectDict from superset.utils import core as utils, import_datasource config = app.config @@ -86,7 +91,7 @@ class AnnotationDatasource(BaseDatasource): cache_timeout = 0 changed_on = None - def query(self, query_obj: Dict[str, Any]) -> QueryResult: + def query(self, query_obj: QueryObjectDict) -> QueryResult: error_message = None qry = db.session.query(Annotation) qry = qry.filter(Annotation.layer_id == query_obj["filter"][0]["val"]) @@ -110,10 +115,10 @@ class AnnotationDatasource(BaseDatasource): error_message=error_message, ) - def get_query_str(self, query_obj): + def get_query_str(self, query_obj: QueryObjectDict) -> str: raise NotImplementedError() - def values_for_column(self, column_name, limit=10000): + def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]: raise NotImplementedError() @@ -239,8 +244,8 @@ class TableColumn(Model, BaseColumn): return self.table.make_sqla_column_compatible(time_expr, label) @classmethod - def import_obj(cls, i_column): - def lookup_obj(lookup_column): + def import_obj(cls, i_column: "TableColumn") -> "TableColumn": + def lookup_obj(lookup_column: TableColumn) -> TableColumn: return ( db.session.query(TableColumn) .filter( @@ -343,8 +348,8 @@ class SqlMetric(Model, BaseMetric): return self.perm @classmethod - def import_obj(cls, i_metric): - def lookup_obj(lookup_metric): + def import_obj(cls, i_metric: "SqlMetric") -> "SqlMetric": + def lookup_obj(lookup_metric: SqlMetric) -> SqlMetric: return ( db.session.query(SqlMetric) .filter( @@ -442,7 +447,7 @@ class SqlaTable(Model, BaseDatasource): sqla_col._df_label_expected = label_expected return sqla_col - def __repr__(self): + def __repr__(self) -> str: return self.name @property @@ -521,14 +526,14 @@ class SqlaTable(Model, BaseDatasource): ) @property - 
def dttm_cols(self) -> List: + def dttm_cols(self) -> List[str]: l = [c.column_name for c in self.columns if c.is_dttm] if self.main_dttm_col and self.main_dttm_col not in l: l.append(self.main_dttm_col) return l @property - def num_cols(self) -> List: + def num_cols(self) -> List[str]: return [c.column_name for c in self.columns if c.is_numeric] @property @@ -550,7 +555,7 @@ class SqlaTable(Model, BaseDatasource): def sql_url(self) -> str: return self.database.sql_url + "?table_name=" + str(self.table_name) - def external_metadata(self): + def external_metadata(self) -> List[Dict[str, str]]: cols = self.database.get_columns(self.table_name, schema=self.schema) for col in cols: try: @@ -567,7 +572,7 @@ class SqlaTable(Model, BaseDatasource): } @property - def select_star(self) -> str: + def select_star(self) -> Optional[str]: # show_cols and latest_partition set to false to avoid # the expensive cost of inspecting the DB return self.database.select_star( @@ -589,7 +594,7 @@ class SqlaTable(Model, BaseDatasource): d["is_sqllab_view"] = self.is_sqllab_view return d - def values_for_column(self, column_name: str, limit: int = 10000) -> List: + def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]: """Runs query against sqla to retrieve some sample values for the given column. """ @@ -626,10 +631,10 @@ class SqlaTable(Model, BaseDatasource): sql = SQL_QUERY_MUTATOR(sql, username, security_manager, self.database) return sql - def get_template_processor(self, **kwargs): + def get_template_processor(self, **kwargs: Any) -> BaseTemplateProcessor: return get_template_processor(table=self, database=self.database, **kwargs) - def get_query_str_extended(self, query_obj: Dict[str, Any]) -> QueryStringExtended: + def get_query_str_extended(self, query_obj: QueryObjectDict) -> QueryStringExtended: sqlaq = self.get_sqla_query(**query_obj) sql = self.database.compile_sqla_query(sqlaq.sqla_query) logger.info(sql) @@ -639,18 +644,20 @@ class SqlaTable(Model, BaseDatasource): labels_expected=sqlaq.labels_expected, sql=sql, prequeries=sqlaq.prequeries ) - def get_query_str(self, query_obj: Dict[str, Any]) -> str: + def get_query_str(self, query_obj: QueryObjectDict) -> str: query_str_ext = self.get_query_str_extended(query_obj) all_queries = query_str_ext.prequeries + [query_str_ext.sql] return ";\n\n".join(all_queries) + ";" - def get_sqla_table(self): + def get_sqla_table(self) -> table: tbl = table(self.table_name) if self.schema: tbl.schema = self.schema return tbl - def get_from_clause(self, template_processor=None): + def get_from_clause( + self, template_processor: Optional[BaseTemplateProcessor] = None + ) -> Union[table, TextAsFrom]: # Supporting arbitrary SQL statements in place of tables if self.sql: from_sql = self.sql @@ -687,7 +694,9 @@ class SqlaTable(Model, BaseDatasource): return self.make_sqla_column_compatible(sqla_metric, label) - def _get_sqla_row_level_filters(self, template_processor) -> List[str]: + def _get_sqla_row_level_filters( + self, template_processor: BaseTemplateProcessor + ) -> List[str]: """ Return the appropriate row level security filters for this table and the current user. 
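A condensed sketch of the branch the new `Union[table, TextAsFrom]` annotation on `get_from_clause()` covers, assuming the SQLAlchemy 1.3-era API this patch targets (`from_clause` is a free-function stand-in for the method, not Superset code):

from typing import Optional, Union

import sqlalchemy as sa
from sqlalchemy.sql import table
from sqlalchemy.sql.expression import TextAsFrom

def from_clause(table_name: str, sql: Optional[str] = None) -> Union[table, TextAsFrom]:
    if sql:
        # Virtual (SQL-defined) datasets become a named subquery.
        return TextAsFrom(sa.text(sql), []).alias("expr_qry")
    # Physical datasets are referenced directly by name.
    return table(table_name)

The same split explains the `Union` return type: a plain `table()` clause for physical tables, a `TextAsFrom`-derived alias when `self.sql` is set.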
@@ -702,22 +711,22 @@ class SqlaTable(Model, BaseDatasource): def get_sqla_query( # sqla self, - metrics, - granularity, - from_dttm, - to_dttm, - columns=None, - groupby=None, - filter=None, - is_timeseries=True, - timeseries_limit=15, - timeseries_limit_metric=None, - row_limit=None, - inner_from_dttm=None, - inner_to_dttm=None, - orderby=None, - extras=None, - order_desc=True, + metrics: List[Metric], + granularity: str, + from_dttm: datetime, + to_dttm: datetime, + columns: Optional[List[str]] = None, + groupby: Optional[List[str]] = None, + filter: Optional[List[Dict[str, Any]]] = None, + is_timeseries: bool = True, + timeseries_limit: int = 15, + timeseries_limit_metric: Optional[Metric] = None, + row_limit: Optional[int] = None, + inner_from_dttm: Optional[datetime] = None, + inner_to_dttm: Optional[datetime] = None, + orderby: Optional[List[Tuple[ColumnElement, bool]]] = None, + extras: Optional[Dict[str, Any]] = None, + order_desc: bool = True, ) -> SqlaQuery: """Querying any sqla table from this common interface""" template_kwargs = { @@ -765,8 +774,9 @@ class SqlaTable(Model, BaseDatasource): metrics_exprs: List[ColumnElement] = [] for m in metrics: if utils.is_adhoc_metric(m): + assert isinstance(m, dict) metrics_exprs.append(self.adhoc_metric_to_sqla(m, cols)) - elif m in metrics_dict: + elif isinstance(m, str) and m in metrics_dict: metrics_exprs.append(metrics_dict[m].get_sqla_col()) else: raise Exception(_("Metric '%(metric)s' does not exist", metric=m)) @@ -781,7 +791,9 @@ class SqlaTable(Model, BaseDatasource): if (is_sip_38 and metrics and columns) or (not is_sip_38 and groupby): # dedup columns while preserving order - groupby = list(dict.fromkeys(columns if is_sip_38 else groupby)) + columns_ = columns if is_sip_38 else groupby + assert columns_ + groupby = list(dict.fromkeys(columns_)) select_exprs = [] for s in groupby: @@ -802,6 +814,7 @@ class SqlaTable(Model, BaseDatasource): ) metrics_exprs = [] + assert extras is not None time_range_endpoints = extras.get("time_range_endpoints") groupby_exprs_with_timestamp = OrderedDict(groupby_exprs_sans_timestamp.items()) if granularity: @@ -845,7 +858,8 @@ class SqlaTable(Model, BaseDatasource): where_clause_and = [] having_clause_and: List = [] - for flt in filter: + + for flt in filter: # type: ignore if not all([flt.get(s) for s in ["col", "op"]]): continue col = flt["col"] @@ -1029,12 +1043,20 @@ class SqlaTable(Model, BaseDatasource): prequeries=prequeries, ) - def _get_timeseries_orderby(self, timeseries_limit_metric, metrics_dict, cols): + def _get_timeseries_orderby( + self, + timeseries_limit_metric: Metric, + metrics_dict: Dict[str, SqlMetric], + cols: Dict[str, Column], + ) -> Optional[Column]: if utils.is_adhoc_metric(timeseries_limit_metric): + assert isinstance(timeseries_limit_metric, dict) ob = self.adhoc_metric_to_sqla(timeseries_limit_metric, cols) - elif timeseries_limit_metric in metrics_dict: - timeseries_limit_metric = metrics_dict.get(timeseries_limit_metric) - ob = timeseries_limit_metric.get_sqla_col() + elif ( + isinstance(timeseries_limit_metric, str) + and timeseries_limit_metric in metrics_dict + ): + ob = metrics_dict[timeseries_limit_metric].get_sqla_col() else: raise Exception( _("Metric '%(metric)s' does not exist", metric=timeseries_limit_metric) @@ -1054,7 +1076,7 @@ class SqlaTable(Model, BaseDatasource): return or_(*groups) - def query(self, query_obj: Dict[str, Any]) -> QueryResult: + def query(self, query_obj: QueryObjectDict) -> QueryResult: qry_start_dttm = datetime.now() 
query_str_ext = self.get_query_str_extended(query_obj) sql = query_str_ext.sql @@ -1101,7 +1123,7 @@ class SqlaTable(Model, BaseDatasource): def get_sqla_table_object(self) -> Table: return self.database.get_table(self.table_name, schema=self.schema) - def fetch_metadata(self, commit=True) -> None: + def fetch_metadata(self, commit: bool = True) -> None: """Fetches the metadata for the table and merges it in""" try: table = self.get_sqla_table_object() @@ -1166,7 +1188,9 @@ class SqlaTable(Model, BaseDatasource): db.session.commit() @classmethod - def import_obj(cls, i_datasource, import_time=None) -> int: + def import_obj( + cls, i_datasource: "SqlaTable", import_time: Optional[int] = None + ) -> int: """Imports the datasource from the object to the database. Metrics and columns and datasource will be overrided if exists. @@ -1174,7 +1198,7 @@ class SqlaTable(Model, BaseDatasource): superset instances. Audit metadata isn't copies over. """ - def lookup_sqlatable(table): + def lookup_sqlatable(table: "SqlaTable") -> "SqlaTable": return ( db.session.query(SqlaTable) .join(Database) @@ -1186,7 +1210,7 @@ class SqlaTable(Model, BaseDatasource): .first() ) - def lookup_database(table): + def lookup_database(table: SqlaTable) -> Database: try: return ( db.session.query(Database) @@ -1207,7 +1231,11 @@ class SqlaTable(Model, BaseDatasource): @classmethod def query_datasources_by_name( - cls, session: Session, database: Database, datasource_name: str, schema=None + cls, + session: Session, + database: Database, + datasource_name: str, + schema: Optional[str] = None, ) -> List["SqlaTable"]: query = ( session.query(cls) @@ -1219,10 +1247,10 @@ class SqlaTable(Model, BaseDatasource): return query.all() @staticmethod - def default_query(qry) -> Query: + def default_query(qry: Query) -> Query: return qry.filter_by(is_sqllab_view=False) - def has_extra_cache_key_calls(self, query_obj: Dict[str, Any]) -> bool: + def has_extra_cache_key_calls(self, query_obj: QueryObjectDict) -> bool: """ Detects the presence of calls to `ExtraCache` methods in items in query_obj that can be templated. If any are present, the query must be evaluated to extract @@ -1248,7 +1276,7 @@ class SqlaTable(Model, BaseDatasource): return True return False - def get_extra_cache_keys(self, query_obj: Dict[str, Any]) -> List[Hashable]: + def get_extra_cache_keys(self, query_obj: QueryObjectDict) -> List[Hashable]: """ The cache key of a SqlaTable needs to consider any keys added by the parent class and any keys added via `ExtraCache`. 
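The intent behind `has_extra_cache_key_calls()` and `get_extra_cache_keys()` can be summarised with the following standalone approximation; the regex and the list of templatable fields are assumptions for illustration, not Superset's exact ones:

import re
from typing import Any, Dict

EXTRA_CACHE_REGEX = re.compile(
    r"\{\{.*(current_user_id\(|current_username\(|url_param\().*\}\}"
)

def has_extra_cache_key_calls(query_obj: Dict[str, Any]) -> bool:
    # Collect the strings that may contain Jinja templating.
    templatable = [query_obj.get("granularity")]
    templatable.extend(flt.get("val") for flt in query_obj.get("filter", []))
    # If any call ExtraCache methods, the query must be rendered to get cache keys.
    return any(isinstance(s, str) and EXTRA_CACHE_REGEX.search(s) for s in templatable)

print(has_extra_cache_key_calls(
    {"filter": [{"col": "id", "op": "==", "val": "{{ current_user_id() }}"}]}
))  # True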
diff --git a/superset/connectors/sqla/views.py b/superset/connectors/sqla/views.py index 38f78d220..2a612a95e 100644 --- a/superset/connectors/sqla/views.py +++ b/superset/connectors/sqla/views.py @@ -18,6 +18,7 @@ """Views used by the SqlAlchemy connector""" import logging import re +from typing import List, Union from flask import flash, Markup, redirect from flask_appbuilder import CompactCRUDMixin, expose @@ -32,6 +33,7 @@ from wtforms.validators import Regexp from superset import app, db, security_manager from superset.connectors.base.views import DatasourceModelView from superset.constants import RouteMethod +from superset.typing import FlaskResponse from superset.utils import core as utils from superset.views.base import ( create_table_permissions, @@ -375,10 +377,10 @@ class TableModelView(DatasourceModelView, DeleteMixin, YamlExportMixin): ) } - def pre_add(self, table): + def pre_add(self, table: "TableModelView") -> None: validate_sqlatable(table) - def post_add(self, table, flash_message=True): + def post_add(self, table: "TableModelView", flash_message: bool = True) -> None: table.fetch_metadata() create_table_permissions(table) if flash_message: @@ -392,15 +394,15 @@ class TableModelView(DatasourceModelView, DeleteMixin, YamlExportMixin): "info", ) - def post_update(self, table): + def post_update(self, table: "TableModelView") -> None: self.post_add(table, flash_message=False) - def _delete(self, pk): + def _delete(self, pk: int) -> None: DeleteMixin._delete(self, pk) @expose("/edit/", methods=["GET", "POST"]) @has_access - def edit(self, pk): + def edit(self, pk: int) -> FlaskResponse: """Simple hack to redirect to explore view after saving""" resp = super(TableModelView, self).edit(pk) if isinstance(resp, str): @@ -410,7 +412,9 @@ class TableModelView(DatasourceModelView, DeleteMixin, YamlExportMixin): @action( "refresh", __("Refresh Metadata"), __("Refresh column metadata"), "fa-refresh" ) - def refresh(self, tables): + def refresh( + self, tables: Union["TableModelView", List["TableModelView"]] + ) -> FlaskResponse: if not isinstance(tables, list): tables = [tables] successes = [] @@ -439,7 +443,7 @@ class TableModelView(DatasourceModelView, DeleteMixin, YamlExportMixin): @expose("/list/") @has_access - def list(self): + def list(self) -> FlaskResponse: if not app.config["ENABLE_REACT_CRUD_VIEWS"]: return super().list() diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 1b0d9768d..a593f5900 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=unused-argument -import dataclasses import hashlib import json import logging @@ -34,6 +33,7 @@ from typing import ( Union, ) +import dataclasses import pandas as pd import sqlparse from flask import g diff --git a/superset/errors.py b/superset/errors.py index 66e2e2f2b..54eb0ed63 100644 --- a/superset/errors.py +++ b/superset/errors.py @@ -15,10 +15,11 @@ # specific language governing permissions and limitations # under the License. 
# pylint: disable=too-few-public-methods,invalid-name -from dataclasses import dataclass from enum import Enum from typing import Any, Dict, Optional +from dataclasses import dataclass + class SupersetErrorType(str, Enum): """ diff --git a/superset/jinja_context.py b/superset/jinja_context.py index 28eb6e2ed..e1a10cd31 100644 --- a/superset/jinja_context.py +++ b/superset/jinja_context.py @@ -17,7 +17,7 @@ """Defines the templating context for SQL Lab""" import inspect import re -from typing import Any, List, Optional, Tuple +from typing import Any, List, Optional, Tuple, TYPE_CHECKING from flask import g, request from jinja2.sandbox import SandboxedEnvironment @@ -26,6 +26,13 @@ from superset import jinja_base_context from superset.extensions import jinja_context_manager from superset.utils.core import convert_legacy_filters_into_adhoc, merge_extra_filters +if TYPE_CHECKING: + from superset.connectors.sqla.models import ( # pylint: disable=unused-import + SqlaTable, + ) + from superset.models.core import Database # pylint: disable=unused-import + from superset.models.sql_lab import Query # pylint: disable=unused-import + def filter_values(column: str, default: Optional[str] = None) -> List[str]: """ Gets a values for a particular filter as a list @@ -200,12 +207,12 @@ class BaseTemplateProcessor: # pylint: disable=too-few-public-methods def __init__( self, - database=None, - query=None, - table=None, + database: Optional["Database"] = None, + query: Optional["Query"] = None, + table: Optional["SqlaTable"] = None, extra_cache_keys: Optional[List[Any]] = None, - **kwargs - ): + **kwargs: Any, + ) -> None: self.database = database self.query = query self.schema = None @@ -230,7 +237,7 @@ class BaseTemplateProcessor: # pylint: disable=too-few-public-methods self.context[self.engine] = self self.env = SandboxedEnvironment() - def process_template(self, sql: str, **kwargs) -> str: + def process_template(self, sql: str, **kwargs: Any) -> str: """Processes a sql template >>> sql = "SELECT '{{ datetime(2017, 1, 1).isoformat() }}'" @@ -279,12 +286,14 @@ class PrestoTemplateProcessor(BaseTemplateProcessor): """ table_name, schema = self._schema_table(table_name, self.schema) - return self.database.db_engine_spec.latest_partition( + assert self.database + return self.database.db_engine_spec.latest_partition( # type: ignore table_name, schema, self.database )[1] def latest_sub_partition(self, table_name, **kwargs): table_name, schema = self._schema_table(table_name, self.schema) + assert self.database return self.database.db_engine_spec.latest_sub_partition( table_name=table_name, schema=schema, database=self.database, **kwargs ) @@ -305,7 +314,12 @@ for k in keys: template_processors[o.engine] = o -def get_template_processor(database, table=None, query=None, **kwargs): +def get_template_processor( + database: "Database", + table: Optional["SqlaTable"] = None, + query: Optional["Query"] = None, + **kwargs: Any, +) -> BaseTemplateProcessor: template_processor = template_processors.get( database.backend, BaseTemplateProcessor ) diff --git a/superset/sql_parse.py b/superset/sql_parse.py index 54289bf43..be9cf1006 100644 --- a/superset/sql_parse.py +++ b/superset/sql_parse.py @@ -15,11 +15,11 @@ # specific language governing permissions and limitations # under the License. 
import logging -from dataclasses import dataclass from typing import List, Optional, Set from urllib import parse import sqlparse +from dataclasses import dataclass from sqlparse.sql import Identifier, IdentifierList, remove_quotes, Token, TokenList from sqlparse.tokens import Keyword, Name, Punctuation, String, Whitespace from sqlparse.utils import imt diff --git a/superset/typing.py b/superset/typing.py index b6686ecfe..f3db6ae56 100644 --- a/superset/typing.py +++ b/superset/typing.py @@ -18,6 +18,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union from flask import Flask from flask_caching import Cache +from werkzeug.wrappers import Response CacheConfig = Union[Callable[[Flask], Cache], Dict[str, Any]] DbapiDescriptionRow = Tuple[ @@ -27,4 +28,15 @@ DbapiDescription = Union[List[DbapiDescriptionRow], Tuple[DbapiDescriptionRow, . DbapiResult = List[Union[List[Any], Tuple[Any, ...]]] FilterValue = Union[float, int, str] FilterValues = Union[FilterValue, List[FilterValue], Tuple[FilterValue]] +Granularity = Union[str, Dict[str, Union[str, float]]] +Metric = Union[Dict[str, str], str] +QueryObjectDict = Dict[str, Any] VizData = Optional[Union[List[Any], Dict[Any, Any]]] + +# Flask response. +Base = Union[bytes, str] +Status = Union[int, str] +Headers = Dict[str, Any] +FlaskResponse = Union[ + Response, Base, Tuple[Base, Status], Tuple[Base, Status, Headers], +] diff --git a/superset/utils/core.py b/superset/utils/core.py index e093d3d05..3618a28c6 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -43,10 +43,12 @@ from typing import ( Any, Callable, Dict, + Iterable, Iterator, List, NamedTuple, Optional, + Sequence, Set, Tuple, TYPE_CHECKING, @@ -79,6 +81,7 @@ from superset.exceptions import ( SupersetException, SupersetTimeoutException, ) +from superset.typing import Metric from superset.utils.dates import datetime_to_epoch, EPOCH try: @@ -101,7 +104,7 @@ JS_MAX_INTEGER = 9007199254740991 # Largest int Java Script can handle 2^53-1 try: # Having might not have been imported. 
     class DimSelector(Having):
-        def __init__(self, **args):
+        def __init__(self, **args: Any) -> None:
             # Just a hack to prevent any exceptions
             Having.__init__(self, type="equalTo", aggregation=None, value=None)
 
@@ -118,7 +121,7 @@ except NameError:
     pass
 
 
-def flasher(msg, severity=None):
+def flasher(msg: str, severity: str = "message") -> None:
     """Flask's flash if available, logging call if not"""
     try:
         flash(msg, severity)
@@ -235,7 +238,7 @@ def list_minus(l: List, minus: List) -> List:
     return [o for o in l if o not in minus]
 
 
-def parse_human_datetime(s: Optional[str]) -> Optional[datetime]:
+def parse_human_datetime(s: str) -> datetime:
     """
     Returns ``datetime.datetime`` from human readable strings
 
     >>> year_ago_1 == year_ago_2
     True
     """
-    if not s:
-        return None
     try:
         dttm = parse(s)
     except Exception:
@@ -564,7 +565,9 @@ def generic_find_uq_constraint_name(table, columns, insp):
     return uq["name"]
 
 
-def get_datasource_full_name(database_name, datasource_name, schema=None):
+def get_datasource_full_name(
+    database_name: str, datasource_name: str, schema: Optional[str] = None
+) -> str:
     if not schema:
         return "[{}].[{}]".format(database_name, datasource_name)
     return "[{}].[{}].[{}]".format(database_name, schema, datasource_name)
@@ -792,7 +795,7 @@ def get_email_address_list(address_string: str) -> List[str]:
     return [x.strip() for x in address_string_list if x.strip()]
 
 
-def choicify(values):
+def choicify(values: Iterable[Any]) -> List[Tuple[Any, Any]]:
     """Takes an iterable and makes an iterable of tuples with it"""
     return [(v, v) for v in values]
@@ -967,7 +970,7 @@ def get_example_database() -> "Database":
     return get_or_create_db("examples", db_uri)
 
 
-def is_adhoc_metric(metric) -> bool:
+def is_adhoc_metric(metric: Metric) -> bool:
     return bool(
         isinstance(metric, dict)
         and (
@@ -985,11 +988,11 @@
     )
 
 
-def get_metric_name(metric):
-    return metric["label"] if is_adhoc_metric(metric) else metric
+def get_metric_name(metric: Metric) -> str:
+    return metric["label"] if is_adhoc_metric(metric) else metric  # type: ignore
 
 
-def get_metric_names(metrics):
+def get_metric_names(metrics: Sequence[Metric]) -> List[str]:
     return [get_metric_name(metric) for metric in metrics]
 
 
diff --git a/superset/utils/import_datasource.py b/superset/utils/import_datasource.py
index 075dbe4ee..19f6d5985 100644
--- a/superset/utils/import_datasource.py
+++ b/superset/utils/import_datasource.py
@@ -15,15 +15,22 @@
 # specific language governing permissions and limitations
 # under the License.
 import logging
+from typing import Callable, Optional
 
+from flask_appbuilder import Model
+from sqlalchemy.orm import Session
 from sqlalchemy.orm.session import make_transient
 
 logger = logging.getLogger(__name__)
 
 
 def import_datasource(
-    session, i_datasource, lookup_database, lookup_datasource, import_time
-):
+    session: Session,
+    i_datasource: Model,
+    lookup_database: Callable,
+    lookup_datasource: Callable,
+    import_time: Optional[int] = None,
+) -> int:
     """Imports the datasource from the object to the database.
 
     Metrics and columns and datasource will be overrided if exists.
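To see how the new `Metric` alias threads through the `utils.core` helpers, here is a self-contained sketch mirroring the typed `is_adhoc_metric`/`get_metric_name` pair in simplified form; the real `is_adhoc_metric` also validates `expressionType` against the SIMPLE/SQL adhoc variants:

from typing import Dict, Union

Metric = Union[Dict[str, str], str]  # as added to superset/typing.py

def is_adhoc_metric(metric: Metric) -> bool:
    # Simplified: the real helper also checks expressionType-specific keys.
    return isinstance(metric, dict) and "label" in metric

def get_metric_name(metric: Metric) -> str:
    return metric["label"] if is_adhoc_metric(metric) else metric  # type: ignore

assert get_metric_name("count") == "count"
assert get_metric_name({"label": "SUM(num)", "expressionType": "SQL"}) == "SUM(num)"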
@@ -75,7 +82,7 @@ def import_datasource( return datasource.id -def import_simple_obj(session, i_obj, lookup_obj): +def import_simple_obj(session: Session, i_obj: Model, lookup_obj: Callable) -> Model: make_transient(i_obj) i_obj.id = None i_obj.table = None diff --git a/superset/views/base.py b/superset/views/base.py index af0c3c320..6995fd182 100644 --- a/superset/views/base.py +++ b/superset/views/base.py @@ -14,13 +14,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import dataclasses import functools import logging import traceback from datetime import datetime -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union +import dataclasses import simplejson as json import yaml from flask import abort, flash, g, get_flashed_messages, redirect, Response, session @@ -48,10 +48,16 @@ from superset.connectors.sqla import models from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import SupersetException, SupersetSecurityException from superset.translations.utils import get_language_pack +from superset.typing import FlaskResponse from superset.utils import core as utils from .utils import bootstrap_user_data +if TYPE_CHECKING: + from superset.connectors.druid.views import ( # pylint: disable=unused-import + DruidClusterModelView, + ) + FRONTEND_CONF_KEYS = ( "SUPERSET_WEBSERVER_TIMEOUT", "SUPERSET_DASHBOARD_POSITION_DATA_LIMIT", @@ -305,7 +311,7 @@ class SupersetModelView(ModelView): page_size = 100 list_widget = SupersetListWidget - def render_app_template(self): + def render_app_template(self) -> FlaskResponse: payload = { "user": bootstrap_user_data(g.user), "common": common_bootstrap_payload(), @@ -359,7 +365,9 @@ class YamlExportMixin: # pylint: disable=too-few-public-methods class DeleteMixin: # pylint: disable=too-few-public-methods - def _delete(self, primary_key): + def _delete( + self: Union[BaseView, "DeleteMixin", "DruidClusterModelView"], primary_key: int, + ) -> None: """ Delete function logic, override to implement diferent logic deletes the record with primary_key = primary_key @@ -367,11 +375,11 @@ class DeleteMixin: # pylint: disable=too-few-public-methods :param primary_key: record primary key to delete """ - item = self.datamodel.get(primary_key, self._base_filters) + item = self.datamodel.get(primary_key, self._base_filters) # type: ignore if not item: abort(404) try: - self.pre_delete(item) + self.pre_delete(item) # type: ignore except Exception as ex: # pylint: disable=broad-except flash(str(ex), "danger") else: @@ -384,8 +392,8 @@ class DeleteMixin: # pylint: disable=too-few-public-methods .all() ) - if self.datamodel.delete(item): - self.post_delete(item) + if self.datamodel.delete(item): # type: ignore + self.post_delete(item) # type: ignore for pv in pvs: security_manager.get_session.delete(pv) @@ -395,8 +403,8 @@ class DeleteMixin: # pylint: disable=too-few-public-methods security_manager.get_session.commit() - flash(*self.datamodel.message) - self.update_redirect() + flash(*self.datamodel.message) # type: ignore + self.update_redirect() # type: ignore @action( "muldelete", __("Delete"), __("Delete all Really?"), "fa-trash", single=False diff --git a/superset/viz.py b/superset/viz.py index c5888de6f..f20d8edb1 100644 --- a/superset/viz.py +++ b/superset/viz.py @@ -21,7 +21,6 @@ These objects represent the backend of all the visualizations that Superset can render. 
""" import copy -import dataclasses import hashlib import inspect import logging @@ -34,6 +33,7 @@ from datetime import datetime, timedelta from itertools import product from typing import Any, Dict, List, Optional, Set, Tuple, TYPE_CHECKING +import dataclasses import geohash import numpy as np import pandas as pd @@ -1049,9 +1049,9 @@ class BubbleViz(NVD3Viz): # dedup groupby if it happens to be the same d["groupby"] = list(dict.fromkeys(d["groupby"])) - self.x_metric = form_data.get("x") - self.y_metric = form_data.get("y") - self.z_metric = form_data.get("size") + self.x_metric = form_data["x"] + self.y_metric = form_data["y"] + self.z_metric = form_data["size"] self.entity = form_data.get("entity") self.series = form_data.get("series") or self.entity d["row_limit"] = form_data.get("limit") @@ -1093,7 +1093,7 @@ class BulletViz(NVD3Viz): def query_obj(self): form_data = self.form_data d = super().query_obj() - self.metric = form_data.get("metric") + self.metric = form_data["metric"] d["metrics"] = [self.metric] if not self.metric: @@ -1451,8 +1451,8 @@ class NVD3DualLineViz(NVD3Viz): _("Pick a time granularity for your time series") ) - metric = utils.get_metric_name(fd.get("metric")) - metric_2 = utils.get_metric_name(fd.get("metric_2")) + metric = utils.get_metric_name(fd["metric"]) + metric_2 = utils.get_metric_name(fd["metric_2"]) df = df.pivot_table(index=DTTM_ALIAS, values=[metric, metric_2]) chart_data = self.to_series(df) @@ -1507,7 +1507,7 @@ class NVD3TimePivotViz(NVD3TimeSeriesViz): df = df.pivot_table( index=DTTM_ALIAS, columns="series", - values=utils.get_metric_name(fd.get("metric")), + values=utils.get_metric_name(fd["metric"]), ) chart_data = self.to_series(df) for serie in chart_data: @@ -1690,8 +1690,12 @@ class SunburstViz(BaseViz): fd = self.form_data cols = fd.get("groupby") or [] cols.extend(["m1", "m2"]) - metric = utils.get_metric_name(fd.get("metric")) - secondary_metric = utils.get_metric_name(fd.get("secondary_metric")) + metric = utils.get_metric_name(fd["metric"]) + secondary_metric = ( + utils.get_metric_name(fd["secondary_metric"]) + if "secondary_metric" in fd + else None + ) if metric == secondary_metric or secondary_metric is None: df.rename(columns={df.columns[-1]: "m1"}, inplace=True) df["m2"] = df["m1"] @@ -1872,8 +1876,12 @@ class WorldMapViz(BaseViz): fd = self.form_data cols = [fd.get("entity")] - metric = utils.get_metric_name(fd.get("metric")) - secondary_metric = utils.get_metric_name(fd.get("secondary_metric")) + metric = utils.get_metric_name(fd["metric"]) + secondary_metric = ( + utils.get_metric_name(fd["secondary_metric"]) + if "secondary_metric" in fd + else None + ) columns = ["country", "m1", "m2"] if metric == secondary_metric: ndf = df[cols] diff --git a/superset/viz_sip38.py b/superset/viz_sip38.py index 0df660ea7..f51580d6d 100644 --- a/superset/viz_sip38.py +++ b/superset/viz_sip38.py @@ -21,7 +21,6 @@ These objects represent the backend of all the visualizations that Superset can render. 
""" import copy -import dataclasses import hashlib import inspect import logging @@ -34,6 +33,7 @@ from datetime import datetime, timedelta from itertools import product from typing import Any, Dict, List, Optional, Set, Tuple, TYPE_CHECKING +import dataclasses import geohash import numpy as np import pandas as pd @@ -1077,9 +1077,9 @@ class BubbleViz(NVD3Viz): form_data = self.form_data d = super().query_obj() - self.x_metric = form_data.get("x") - self.y_metric = form_data.get("y") - self.z_metric = form_data.get("size") + self.x_metric = form_data["x"] + self.y_metric = form_data["y"] + self.z_metric = form_data["size"] self.entity = form_data.get("entity") self.series = form_data.get("series") or self.entity d["row_limit"] = form_data.get("limit") @@ -1476,8 +1476,8 @@ class NVD3DualLineViz(NVD3Viz): _("Pick a time granularity for your time series") ) - metric = utils.get_metric_name(fd.get("metric")) - metric_2 = utils.get_metric_name(fd.get("metric_2")) + metric = utils.get_metric_name(fd["metric"]) + metric_2 = utils.get_metric_name(fd["metric_2"]) df = df.pivot_table(index=DTTM_ALIAS, values=[metric, metric_2]) chart_data = self.to_series(df) @@ -1532,7 +1532,7 @@ class NVD3TimePivotViz(NVD3TimeSeriesViz): df = df.pivot_table( index=DTTM_ALIAS, columns="series", - values=utils.get_metric_name(fd.get("metric")), + values=utils.get_metric_name(fd["metric"]), ) chart_data = self.to_series(df) for serie in chart_data: @@ -1710,8 +1710,12 @@ class SunburstViz(BaseViz): fd = self.form_data cols = fd.get("groupby") or [] cols.extend(["m1", "m2"]) - metric = utils.get_metric_name(fd.get("metric")) - secondary_metric = utils.get_metric_name(fd.get("secondary_metric")) + metric = utils.get_metric_name(fd["metric"]) + secondary_metric = ( + utils.get_metric_name(fd["secondary_metric"]) + if "secondary_metric" in fd + else None + ) if metric == secondary_metric or secondary_metric is None: df.rename(columns={df.columns[-1]: "m1"}, inplace=True) df["m2"] = df["m1"] @@ -1868,8 +1872,12 @@ class WorldMapViz(BaseViz): fd = self.form_data cols = [fd.get("entity")] - metric = utils.get_metric_name(fd.get("metric")) - secondary_metric = utils.get_metric_name(fd.get("secondary_metric")) + metric = utils.get_metric_name(fd["metric"]) + secondary_metric = ( + utils.get_metric_name(fd["secondary_metric"]) + if "secondary_metric" in fd + else None + ) columns = ["country", "m1", "m2"] if metric == secondary_metric: ndf = df[cols]